diff --git a/egs/ljspeech/TTS/matcha/fbank.py b/egs/ljspeech/TTS/matcha/fbank.py
index d729fa425f..cc94a301f2 100644
--- a/egs/ljspeech/TTS/matcha/fbank.py
+++ b/egs/ljspeech/TTS/matcha/fbank.py
@@ -17,6 +17,7 @@ class MatchaFbankConfig:
     win_length: int
     f_min: float
     f_max: float
+    device: str = "cuda"

 @register_extractor
@@ -46,7 +47,7 @@ def extract(
             f"Mismatched sampling rate: extractor expects {expected_sr}, "
             f"got {sampling_rate}"
         )
-        samples = torch.from_numpy(samples)
+        samples = torch.from_numpy(samples).to(self.device)
         assert samples.ndim == 2, samples.shape
         assert samples.shape[0] == 1, samples.shape

@@ -81,7 +82,7 @@ def extract(
             mel, (0, 0, 0, num_frames - mel.shape[1]), mode="replicate"
         ).squeeze(0)

-        return mel.numpy()
+        return mel.cpu().numpy()

     @property
     def frame_shift(self) -> Seconds:
diff --git a/egs/wenetspeech4tts/TTS/README.md b/egs/wenetspeech4tts/TTS/README.md
index f35bb51c7f..d47687acd4 100644
--- a/egs/wenetspeech4tts/TTS/README.md
+++ b/egs/wenetspeech4tts/TTS/README.md
@@ -68,5 +68,69 @@ python3 valle/infer.py --output-dir demos_epoch_${epoch}_avg_${avg}_top_p_${top_
   --text-extractor pypinyin_initials_finals --top-p ${top_p}
 ```

+# [F5-TTS](https://arxiv.org/abs/2410.06885)
+
+./f5-tts contains the code for training the F5-TTS model.
+
+Generated samples and training logs for the WenetSpeech4TTS Basic 7k-hour data can be found [here](https://huggingface.co/yuekai/f5-tts-small-wenetspeech4tts-basic/tensorboard).
+
+Preparation:
+
+```
+bash prepare.sh --stage 5 --stop_stage 6
+```
+(Note: To be compatible with the official F5-TTS checkpoint, we directly use the `vocab.txt` from [here](https://github.com/SWivid/F5-TTS/blob/129014c5b43f135b0100d49a0c6804dd4cf673e1/data/Emilia_ZH_EN_pinyin/vocab.txt). To generate your own `vocab.txt`, you may refer to [the script](https://github.com/SWivid/F5-TTS/blob/main/src/f5_tts/train/datasets/prepare_emilia.py).)
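+
+For reference, `vocab.txt` is a plain-text list with one token per line: the line index is the token id, and the space token must sit at index 0 because id 0 also stands in for unknown characters. A minimal sketch of loading it into the char-to-id map (mirroring the `get_tokenizer` helper added in this PR; the path below is illustrative):
+
+```
+vocab_char_map = {}
+with open("f5-tts/vocab.txt", "r", encoding="utf-8") as f:
+    for i, line in enumerate(f):
+        vocab_char_map[line[:-1]] = i  # strip the trailing newline; line index == token id
+vocab_size = len(vocab_char_map)
+assert vocab_char_map[" "] == 0, "space must have id 0; it doubles as the unknown-char id"
+```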
+ +The training command is given below: + +``` +# docker: ghcr.io/swivid/f5-tts:main +# pip install k2==1.24.4.dev20241030+cuda12.4.torch2.4.0 -f https://k2-fsa.github.io/k2/cuda.html +# pip install kaldialign lhotse tensorboard bigvganinference sentencepiece + +world_size=8 +exp_dir=exp/f5-tts-small +python3 f5-tts/train.py --max-duration 700 --filter-min-duration 0.5 --filter-max-duration 20 \ + --num-buckets 6 --dtype "bfloat16" --save-every-n 5000 --valid-interval 10000 \ + --base-lr 7.5e-5 --warmup-steps 20000 --num-epochs 60 \ + --num-decoder-layers 18 --nhead 12 --decoder-dim 768 \ + --exp-dir ${exp_dir} --world-size ${world_size} +``` + +To inference with Icefall Wenetspeech4TTS trained F5-Small, use: +``` +huggingface-cli login +huggingface-cli download --local-dir seed_tts_eval yuekai/seed_tts_eval --repo-type dataset +huggingface-cli download --local-dir ${exp_dir} yuekai/f5-tts-small-wenetspeech4tts-basic +huggingface-cli download nvidia/bigvgan_v2_24khz_100band_256x --local-dir bigvgan_v2_24khz_100band_256x + +manifest=./seed_tts_eval/seedtts_testset/zh/meta.lst +model_path=f5-tts-small-wenetspeech4tts-basic/epoch-56-avg-14.pt +# skip +python3 f5-tts/generate_averaged_model.py \ + --epoch 56 \ + --avg 14 --decoder-dim 768 --nhead 12 --num-decoder-layers 18 \ + --exp-dir exp/f5_small + + +accelerate launch f5-tts/infer.py --nfe 16 --model-path $model_path --manifest-file $manifest --output-dir $output_dir --decoder-dim 768 --nhead 12 --num-decoder-layers 18 +bash local/compute_wer.sh $output_dir $manifest +``` + +To inference with official Emilia trained F5-Base, use: +``` +huggingface-cli login +huggingface-cli download --local-dir seed_tts_eval yuekai/seed_tts_eval --repo-type dataset +huggingface-cli download --local-dir F5-TTS SWivid/F5-TTS +huggingface-cli download nvidia/bigvgan_v2_24khz_100band_256x --local-dir bigvgan_v2_24khz_100band_256x + +manifest=./seed_tts_eval/seedtts_testset/zh/meta.lst +model_path=./F5-TTS/F5TTS_Base_bigvgan/model_1250000.pt + +accelerate launch f5-tts/infer.py --nfe 16 --model-path $model_path --manifest-file $manifest --output-dir $output_dir +bash local/compute_wer.sh $output_dir $manifest +``` + # Credits -- [vall-e](https://github.com/lifeiteng/vall-e) +- [VALL-E](https://github.com/lifeiteng/vall-e) +- [F5-TTS](https://github.com/SWivid/F5-TTS) diff --git a/egs/wenetspeech4tts/TTS/f5-tts/generate_averaged_model.py b/egs/wenetspeech4tts/TTS/f5-tts/generate_averaged_model.py new file mode 100644 index 0000000000..f023585535 --- /dev/null +++ b/egs/wenetspeech4tts/TTS/f5-tts/generate_averaged_model.py @@ -0,0 +1,173 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Yifan Yang) +# Copyright 2024 Yuekai Zhang +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +Usage: +(1) use the checkpoint exp_dir/epoch-xxx.pt +python3 bin/generate_averaged_model.py \ + --epoch 40 \ + --avg 5 \ + --exp-dir ${exp_dir} + +It will generate a file `epoch-28-avg-15.pt` in the given `exp_dir`. +You can later load it by `torch.load("epoch-28-avg-15.pt")`. +""" + + +import argparse +from pathlib import Path + +import k2 +import torch +from train import add_model_arguments, get_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, +) +from icefall.utils import AttributeDict + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=9, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="The experiment dir", + ) + add_model_arguments(parser) + return parser + + +@torch.no_grad() +def main(): + parser = get_parser() + + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = AttributeDict() + params.update(vars(args)) + + if params.iter > 0: + params.suffix = f"checkpoint-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + print("Script started") + + device = torch.device("cpu") + print(f"Device: {device}") + + print("About to create model") + filename = f"{params.exp_dir}/epoch-{params.epoch}.pt" + checkpoint = torch.load(filename, map_location=device) + args = AttributeDict(checkpoint) + model = get_model(args) + + if params.iter > 0: + # TODO FIX ME + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + print( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + filename = params.exp_dir / f"checkpoint-{params.iter}-avg-{params.avg}.pt" + torch.save({"model": model.state_dict()}, filename) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + print( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + filenames = [ + f"{params.exp_dir}/epoch-{i}.pt" for i in range(start, params.epoch + 1) + ] + 
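+        # average_checkpoints loads each checkpoint listed in `filenames` and returns the
+        # element-wise mean of their model parameters, which is then loaded into the model.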
model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + + filename = params.exp_dir / f"epoch-{params.epoch}-avg-{params.avg}.pt" + checkpoint["model"] = model.state_dict() + torch.save(checkpoint, filename) + + num_param = sum([p.numel() for p in model.parameters()]) + print(f"Number of model parameters: {num_param}") + + print("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/wenetspeech4tts/TTS/f5-tts/infer.py b/egs/wenetspeech4tts/TTS/f5-tts/infer.py new file mode 100644 index 0000000000..02e5f0f4d8 --- /dev/null +++ b/egs/wenetspeech4tts/TTS/f5-tts/infer.py @@ -0,0 +1,364 @@ +#!/usr/bin/env python3 +# Modified from https://github.com/SWivid/F5-TTS/blob/main/src/f5_tts/eval/eval_infer_batch.py +""" +Usage: +# docker: ghcr.io/swivid/f5-tts:main +# pip install k2==1.24.4.dev20241030+cuda12.4.torch2.4.0 -f https://k2-fsa.github.io/k2/cuda.html +# pip install kaldialign lhotse tensorboard bigvganinference sentencepiece sherpa-onnx +# huggingface-cli download nvidia/bigvgan_v2_24khz_100band_256x --local-dir bigvgan_v2_24khz_100band_256x +manifest=/path/seed_tts_eval/seedtts_testset/zh/meta.lst +python3 f5-tts/generate_averaged_model.py \ + --epoch 56 \ + --avg 14 --decoder-dim 768 --nhead 12 --num-decoder-layers 18 \ + --exp-dir exp/f5_small +accelerate launch f5-tts/infer.py --nfe 16 --model-path $model_path --manifest-file $manifest --output-dir $output_dir --decoder-dim 768 --nhead 12 --num-decoder-layers 18 +bash local/compute_wer.sh $output_dir $manifest +""" +import argparse +import logging +import math +import os +import random +import time +from pathlib import Path + +import torch +import torch.nn.functional as F +import torchaudio +from accelerate import Accelerator +from bigvganinference import BigVGANInference +from model.cfm import CFM +from model.dit import DiT +from model.modules import MelSpec +from model.utils import convert_char_to_pinyin +from tqdm import tqdm +from train import ( + add_model_arguments, + get_model, + get_tokenizer, + load_F5_TTS_pretrained_checkpoint, +) + +from icefall.checkpoint import load_checkpoint + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--tokens", + type=str, + default="f5-tts/vocab.txt", + help="Path to the unique text tokens file", + ) + + parser.add_argument( + "--model-path", + type=str, + default="/home/yuekaiz/HF/F5-TTS/F5TTS_Base_bigvgan/model_1250000.pt", + help="Path to the unique text tokens file", + ) + + parser.add_argument( + "--seed", + type=int, + default=0, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--nfe", + type=int, + default=16, + help="The number of steps for the neural ODE", + ) + + parser.add_argument( + "--manifest-file", + type=str, + default="/path/seed_tts_eval/seedtts_testset/zh/meta.lst", + help="The manifest file in seed_tts_eval format", + ) + + parser.add_argument( + "--output-dir", + type=Path, + default="results", + help="The output directory to save the generated wavs", + ) + + parser.add_argument("-ss", "--swaysampling", default=-1, type=float) + add_model_arguments(parser) + return parser.parse_args() + + +def get_inference_prompt( + metainfo, + speed=1.0, + tokenizer="pinyin", + polyphone=True, + target_sample_rate=24000, + n_fft=1024, + win_length=1024, + n_mel_channels=100, + hop_length=256, + mel_spec_type="bigvgan", + target_rms=0.1, + use_truth_duration=False, + infer_batch_size=1, + 
num_buckets=200, + min_secs=3, + max_secs=40, +): + prompts_all = [] + + min_tokens = min_secs * target_sample_rate // hop_length + max_tokens = max_secs * target_sample_rate // hop_length + + batch_accum = [0] * num_buckets + utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = ( + [[] for _ in range(num_buckets)] for _ in range(6) + ) + + mel_spectrogram = MelSpec( + n_fft=n_fft, + hop_length=hop_length, + win_length=win_length, + n_mel_channels=n_mel_channels, + target_sample_rate=target_sample_rate, + mel_spec_type=mel_spec_type, + ) + + for utt, prompt_text, prompt_wav, gt_text, gt_wav in tqdm( + metainfo, desc="Processing prompts..." + ): + # Audio + ref_audio, ref_sr = torchaudio.load(prompt_wav) + ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio))) + if ref_rms < target_rms: + ref_audio = ref_audio * target_rms / ref_rms + assert ( + ref_audio.shape[-1] > 5000 + ), f"Empty prompt wav: {prompt_wav}, or torchaudio backend issue." + if ref_sr != target_sample_rate: + resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate) + ref_audio = resampler(ref_audio) + + # Text + if len(prompt_text[-1].encode("utf-8")) == 1: + prompt_text = prompt_text + " " + text = [prompt_text + gt_text] + if tokenizer == "pinyin": + text_list = convert_char_to_pinyin(text, polyphone=polyphone) + else: + text_list = text + + # Duration, mel frame length + ref_mel_len = ref_audio.shape[-1] // hop_length + if use_truth_duration: + gt_audio, gt_sr = torchaudio.load(gt_wav) + if gt_sr != target_sample_rate: + resampler = torchaudio.transforms.Resample(gt_sr, target_sample_rate) + gt_audio = resampler(gt_audio) + total_mel_len = ref_mel_len + int(gt_audio.shape[-1] / hop_length / speed) + + # # test vocoder resynthesis + # ref_audio = gt_audio + else: + ref_text_len = len(prompt_text.encode("utf-8")) + gen_text_len = len(gt_text.encode("utf-8")) + total_mel_len = ref_mel_len + int( + ref_mel_len / ref_text_len * gen_text_len / speed + ) + + # to mel spectrogram + ref_mel = mel_spectrogram(ref_audio) + ref_mel = ref_mel.squeeze(0) + + # deal with batch + assert infer_batch_size > 0, "infer_batch_size should be greater than 0." + assert ( + min_tokens <= total_mel_len <= max_tokens + ), f"Audio {utt} has duration {total_mel_len*hop_length//target_sample_rate}s out of range [{min_secs}, {max_secs}]." 
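+        # Bucket assignment: total_mel_len lies in [min_tokens, max_tokens]; map it linearly to one of
+        # `num_buckets` equally sized duration buckets so prompts of similar length are batched together.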
+ bucket_i = math.floor( + (total_mel_len - min_tokens) / (max_tokens - min_tokens + 1) * num_buckets + ) + + utts[bucket_i].append(utt) + ref_rms_list[bucket_i].append(ref_rms) + ref_mels[bucket_i].append(ref_mel) + ref_mel_lens[bucket_i].append(ref_mel_len) + total_mel_lens[bucket_i].append(total_mel_len) + final_text_list[bucket_i].extend(text_list) + + batch_accum[bucket_i] += total_mel_len + + if batch_accum[bucket_i] >= infer_batch_size: + prompts_all.append( + ( + utts[bucket_i], + ref_rms_list[bucket_i], + padded_mel_batch(ref_mels[bucket_i]), + ref_mel_lens[bucket_i], + total_mel_lens[bucket_i], + final_text_list[bucket_i], + ) + ) + batch_accum[bucket_i] = 0 + ( + utts[bucket_i], + ref_rms_list[bucket_i], + ref_mels[bucket_i], + ref_mel_lens[bucket_i], + total_mel_lens[bucket_i], + final_text_list[bucket_i], + ) = ( + [], + [], + [], + [], + [], + [], + ) + + # add residual + for bucket_i, bucket_frames in enumerate(batch_accum): + if bucket_frames > 0: + prompts_all.append( + ( + utts[bucket_i], + ref_rms_list[bucket_i], + padded_mel_batch(ref_mels[bucket_i]), + ref_mel_lens[bucket_i], + total_mel_lens[bucket_i], + final_text_list[bucket_i], + ) + ) + # not only leave easy work for last workers + random.seed(666) + random.shuffle(prompts_all) + + return prompts_all + + +def padded_mel_batch(ref_mels): + max_mel_length = torch.LongTensor([mel.shape[-1] for mel in ref_mels]).amax() + padded_ref_mels = [] + for mel in ref_mels: + padded_ref_mel = F.pad(mel, (0, max_mel_length - mel.shape[-1]), value=0) + padded_ref_mels.append(padded_ref_mel) + padded_ref_mels = torch.stack(padded_ref_mels) + padded_ref_mels = padded_ref_mels.permute(0, 2, 1) + return padded_ref_mels + + +def get_seedtts_testset_metainfo(metalst): + f = open(metalst) + lines = f.readlines() + f.close() + metainfo = [] + for line in lines: + assert len(line.strip().split("|")) == 4 + utt, prompt_text, prompt_wav, gt_text = line.strip().split("|") + utt = Path(utt).stem + gt_wav = os.path.join(os.path.dirname(metalst), "wavs", utt + ".wav") + if not os.path.isabs(prompt_wav): + prompt_wav = os.path.join(os.path.dirname(metalst), prompt_wav) + metainfo.append((utt, prompt_text, prompt_wav, gt_text, gt_wav)) + return metainfo + + +def main(): + args = get_parser() + + accelerator = Accelerator() + device = f"cuda:{accelerator.process_index}" + + metainfo = get_seedtts_testset_metainfo(args.manifest_file) + prompts_all = get_inference_prompt( + metainfo, + speed=1.0, + tokenizer="pinyin", + target_sample_rate=24_000, + n_mel_channels=100, + hop_length=256, + mel_spec_type="bigvgan", + target_rms=0.1, + use_truth_duration=False, + infer_batch_size=1, + ) + + vocoder = BigVGANInference.from_pretrained( + "./bigvgan_v2_24khz_100band_256x", use_cuda_kernel=False + ) + vocoder = vocoder.eval().to(device) + + model = get_model(args).eval().to(device) + checkpoint = torch.load(args.model_path, map_location="cpu") + if "ema_model_state_dict" in checkpoint or "model_state_dict" in checkpoint: + model = load_F5_TTS_pretrained_checkpoint(model, args.model_path) + else: + _ = load_checkpoint( + args.model_path, + model=model, + ) + + os.makedirs(args.output_dir, exist_ok=True) + + accelerator.wait_for_everyone() + start = time.time() + + with accelerator.split_between_processes(prompts_all) as prompts: + for prompt in tqdm(prompts, disable=not accelerator.is_local_main_process): + ( + utts, + ref_rms_list, + ref_mels, + ref_mel_lens, + total_mel_lens, + final_text_list, + ) = prompt + ref_mels = ref_mels.to(device) + ref_mel_lens = 
torch.tensor(ref_mel_lens, dtype=torch.long).to(device) + total_mel_lens = torch.tensor(total_mel_lens, dtype=torch.long).to(device) + + # Inference + with torch.inference_mode(): + generated, _ = model.sample( + cond=ref_mels, + text=final_text_list, + duration=total_mel_lens, + lens=ref_mel_lens, + steps=args.nfe, + cfg_strength=2.0, + sway_sampling_coef=args.swaysampling, + no_ref_audio=False, + seed=args.seed, + ) + for i, gen in enumerate(generated): + gen = gen[ref_mel_lens[i] : total_mel_lens[i], :].unsqueeze(0) + gen_mel_spec = gen.permute(0, 2, 1).to(torch.float32) + + generated_wave = vocoder(gen_mel_spec).squeeze(0).cpu() + target_rms = 0.1 + target_sample_rate = 24_000 + if ref_rms_list[i] < target_rms: + generated_wave = generated_wave * ref_rms_list[i] / target_rms + torchaudio.save( + f"{args.output_dir}/{utts[i]}.wav", + generated_wave, + target_sample_rate, + ) + + accelerator.wait_for_everyone() + if accelerator.is_main_process: + timediff = time.time() - start + print(f"Done batch inference in {timediff / 60 :.2f} minutes.") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/wenetspeech4tts/TTS/f5-tts/model/README.md b/egs/wenetspeech4tts/TTS/f5-tts/model/README.md new file mode 100644 index 0000000000..e4a7e2a7ca --- /dev/null +++ b/egs/wenetspeech4tts/TTS/f5-tts/model/README.md @@ -0,0 +1,3 @@ +# Introduction +Files in this folder are copied from +https://github.com/SWivid/F5-TTS/tree/main/src/f5_tts/model diff --git a/egs/wenetspeech4tts/TTS/f5-tts/model/cfm.py b/egs/wenetspeech4tts/TTS/f5-tts/model/cfm.py new file mode 100644 index 0000000000..349c7220e2 --- /dev/null +++ b/egs/wenetspeech4tts/TTS/f5-tts/model/cfm.py @@ -0,0 +1,326 @@ +""" +ein notation: +b - batch +n - sequence +nt - text sequence +nw - raw wave length +d - dimension +""" + +from __future__ import annotations + +from random import random +from typing import Callable + +import torch +import torch.nn.functional as F +from model.modules import MelSpec +from model.utils import ( + default, + exists, + lens_to_mask, + list_str_to_idx, + list_str_to_tensor, + mask_from_frac_lengths, +) +from torch import nn +from torch.nn.utils.rnn import pad_sequence +from torchdiffeq import odeint + + +class CFM(nn.Module): + def __init__( + self, + transformer: nn.Module, + sigma=0.0, + odeint_kwargs: dict = dict( + # atol = 1e-5, + # rtol = 1e-5, + method="euler" # 'midpoint' + ), + audio_drop_prob=0.3, + cond_drop_prob=0.2, + num_channels=None, + mel_spec_module: nn.Module | None = None, + mel_spec_kwargs: dict = dict(), + frac_lengths_mask: tuple[float, float] = (0.7, 1.0), + vocab_char_map: dict[str:int] | None = None, + ): + super().__init__() + + self.frac_lengths_mask = frac_lengths_mask + + # mel spec + self.mel_spec = default(mel_spec_module, MelSpec(**mel_spec_kwargs)) + num_channels = default(num_channels, self.mel_spec.n_mel_channels) + self.num_channels = num_channels + + # classifier-free guidance + self.audio_drop_prob = audio_drop_prob + self.cond_drop_prob = cond_drop_prob + + # transformer + self.transformer = transformer + dim = transformer.dim + self.dim = dim + + # conditional flow related + self.sigma = sigma + + # sampling related + self.odeint_kwargs = odeint_kwargs + + # vocab map for tokenization + self.vocab_char_map = vocab_char_map + + @property + def device(self): + return next(self.parameters()).device + + @torch.no_grad() + def sample( + self, 
+ cond: float["b n d"] | float["b nw"], # noqa: F722 + text: int["b nt"] | list[str], # noqa: F722 + duration: int | int["b"], # noqa: F821 + *, + lens: int["b"] | None = None, # noqa: F821 + steps=32, + cfg_strength=1.0, + sway_sampling_coef=None, + seed: int | None = None, + max_duration=4096, + vocoder: Callable[[float["b d n"]], float["b nw"]] | None = None, # noqa: F722 + no_ref_audio=False, + duplicate_test=False, + t_inter=0.1, + edit_mask=None, + ): + self.eval() + # raw wave + + if cond.ndim == 2: + cond = self.mel_spec(cond) + cond = cond.permute(0, 2, 1) + assert cond.shape[-1] == self.num_channels + + cond = cond.to(next(self.parameters()).dtype) + + batch, cond_seq_len, device = *cond.shape[:2], cond.device + if not exists(lens): + lens = torch.full((batch,), cond_seq_len, device=device, dtype=torch.long) + + # text + + if isinstance(text, list): + if exists(self.vocab_char_map): + text = list_str_to_idx(text, self.vocab_char_map).to(device) + else: + text = list_str_to_tensor(text).to(device) + assert text.shape[0] == batch + + if exists(text): + text_lens = (text != -1).sum(dim=-1) + lens = torch.maximum( + text_lens, lens + ) # make sure lengths are at least those of the text characters + + # duration + + cond_mask = lens_to_mask(lens) + if edit_mask is not None: + cond_mask = cond_mask & edit_mask + + if isinstance(duration, int): + duration = torch.full((batch,), duration, device=device, dtype=torch.long) + + duration = torch.maximum( + lens + 1, duration + ) # just add one token so something is generated + duration = duration.clamp(max=max_duration) + max_duration = duration.amax() + + # duplicate test corner for inner time step oberservation + if duplicate_test: + test_cond = F.pad( + cond, (0, 0, cond_seq_len, max_duration - 2 * cond_seq_len), value=0.0 + ) + + cond = F.pad(cond, (0, 0, 0, max_duration - cond_seq_len), value=0.0) + cond_mask = F.pad( + cond_mask, (0, max_duration - cond_mask.shape[-1]), value=False + ) + cond_mask = cond_mask.unsqueeze(-1) + step_cond = torch.where( + cond_mask, cond, torch.zeros_like(cond) + ) # allow direct control (cut cond audio) with lens passed in + + if batch > 1: + mask = lens_to_mask(duration) + else: # save memory and speed up, as single inference need no mask currently + mask = None + + # test for no ref audio + if no_ref_audio: + cond = torch.zeros_like(cond) + + # neural ode + + def fn(t, x): + # at each step, conditioning is fixed + # step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond)) + + # predict flow + pred = self.transformer( + x=x, + cond=step_cond, + text=text, + time=t, + mask=mask, + drop_audio_cond=False, + drop_text=False, + ) + if cfg_strength < 1e-5: + return pred + + null_pred = self.transformer( + x=x, + cond=step_cond, + text=text, + time=t, + mask=mask, + drop_audio_cond=True, + drop_text=True, + ) + return pred + (pred - null_pred) * cfg_strength + + # noise input + # to make sure batch inference result is same with different batch size, and for sure single inference + # still some difference maybe due to convolutional layers + y0 = [] + for dur in duration: + if exists(seed): + torch.manual_seed(seed) + y0.append( + torch.randn( + dur, self.num_channels, device=self.device, dtype=step_cond.dtype + ) + ) + y0 = pad_sequence(y0, padding_value=0, batch_first=True) + + t_start = 0 + + # duplicate test corner for inner time step oberservation + if duplicate_test: + t_start = t_inter + y0 = (1 - t_start) * y0 + t_start * test_cond + steps = int(steps * (1 - t_start)) + + t = torch.linspace( + 
t_start, 1, steps + 1, device=self.device, dtype=step_cond.dtype + ) + if sway_sampling_coef is not None: + t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t) + + trajectory = odeint(fn, y0, t, **self.odeint_kwargs) + + sampled = trajectory[-1] + out = sampled + out = torch.where(cond_mask, cond, out) + + if exists(vocoder): + out = out.permute(0, 2, 1) + out = vocoder(out) + + return out, trajectory + + def forward( + self, + inp: float["b n d"] | float["b nw"], # mel or raw wave # noqa: F722 + text: int["b nt"] | list[str], # noqa: F722 + *, + lens: int["b"] | None = None, # noqa: F821 + noise_scheduler: str | None = None, + ): + # handle raw wave + if inp.ndim == 2: + inp = self.mel_spec(inp) + inp = inp.permute(0, 2, 1) + assert inp.shape[-1] == self.num_channels + + batch, seq_len, dtype, device, _σ1 = ( + *inp.shape[:2], + inp.dtype, + self.device, + self.sigma, + ) + + # handle text as string + if isinstance(text, list): + if exists(self.vocab_char_map): + text = list_str_to_idx(text, self.vocab_char_map).to(device) + else: + text = list_str_to_tensor(text).to(device) + assert text.shape[0] == batch + + # lens and mask + if not exists(lens): + lens = torch.full((batch,), seq_len, device=device) + + mask = lens_to_mask( + lens, length=seq_len + ) # useless here, as collate_fn will pad to max length in batch + + # get a random span to mask out for training conditionally + frac_lengths = ( + torch.zeros((batch,), device=self.device) + .float() + .uniform_(*self.frac_lengths_mask) + ) + rand_span_mask = mask_from_frac_lengths(lens, frac_lengths) + + if exists(mask): + rand_span_mask &= mask + + # mel is x1 + x1 = inp + + # x0 is gaussian noise + x0 = torch.randn_like(x1) + + # time step + time = torch.rand((batch,), dtype=dtype, device=self.device) + # TODO. 
noise_scheduler + + # sample xt (φ_t(x) in the paper) + t = time.unsqueeze(-1).unsqueeze(-1) + φ = (1 - t) * x0 + t * x1 + flow = x1 - x0 + + # only predict what is within the random mask span for infilling + cond = torch.where(rand_span_mask[..., None], torch.zeros_like(x1), x1) + + # transformer and cfg training with a drop rate + drop_audio_cond = random() < self.audio_drop_prob # p_drop in voicebox paper + if random() < self.cond_drop_prob: # p_uncond in voicebox paper + drop_audio_cond = True + drop_text = True + else: + drop_text = False + + # if want rigourously mask out padding, record in collate_fn in dataset.py, and pass in here + # adding mask will use more memory, thus also need to adjust batchsampler with scaled down threshold for long sequences + pred = self.transformer( + x=φ, + cond=cond, + text=text, + time=time, + drop_audio_cond=drop_audio_cond, + drop_text=drop_text, + ) + + # flow matching loss + loss = F.mse_loss(pred, flow, reduction="none") + loss = loss[rand_span_mask] + + return loss.mean(), cond, pred diff --git a/egs/wenetspeech4tts/TTS/f5-tts/model/dit.py b/egs/wenetspeech4tts/TTS/f5-tts/model/dit.py new file mode 100644 index 0000000000..966fabfdd4 --- /dev/null +++ b/egs/wenetspeech4tts/TTS/f5-tts/model/dit.py @@ -0,0 +1,210 @@ +""" +ein notation: +b - batch +n - sequence +nt - text sequence +nw - raw wave length +d - dimension +""" + +from __future__ import annotations + +import torch +import torch.nn.functional as F +from model.modules import ( + AdaLayerNormZero_Final, + ConvNeXtV2Block, + ConvPositionEmbedding, + DiTBlock, + TimestepEmbedding, + get_pos_embed_indices, + precompute_freqs_cis, +) +from torch import nn +from x_transformers.x_transformers import RotaryEmbedding + +# Text embedding + + +class TextEmbedding(nn.Module): + def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2): + super().__init__() + self.text_embed = nn.Embedding( + text_num_embeds + 1, text_dim + ) # use 0 as filler token + + if conv_layers > 0: + self.extra_modeling = True + self.precompute_max_pos = 4096 # ~44s of 24khz audio + self.register_buffer( + "freqs_cis", + precompute_freqs_cis(text_dim, self.precompute_max_pos), + persistent=False, + ) + self.text_blocks = nn.Sequential( + *[ + ConvNeXtV2Block(text_dim, text_dim * conv_mult) + for _ in range(conv_layers) + ] + ) + else: + self.extra_modeling = False + + def forward(self, text: int["b nt"], seq_len, drop_text=False): # noqa: F722 + text = ( + text + 1 + ) # use 0 as filler token. 
preprocess of batch pad -1, see list_str_to_idx() + text = text[ + :, :seq_len + ] # curtail if character tokens are more than the mel spec tokens + batch, text_len = text.shape[0], text.shape[1] + text = F.pad(text, (0, seq_len - text_len), value=0) + + if drop_text: # cfg for text + text = torch.zeros_like(text) + + text = self.text_embed(text) # b n -> b n d + + # possible extra modeling + if self.extra_modeling: + # sinus pos emb + batch_start = torch.zeros((batch,), dtype=torch.long) + pos_idx = get_pos_embed_indices( + batch_start, seq_len, max_pos=self.precompute_max_pos + ) + text_pos_embed = self.freqs_cis[pos_idx] + text = text + text_pos_embed + + # convnextv2 blocks + text = self.text_blocks(text) + + return text + + +# noised input audio and context mixing embedding + + +class InputEmbedding(nn.Module): + def __init__(self, mel_dim, text_dim, out_dim): + super().__init__() + self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim) + self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim) + + def forward( + self, + x: float["b n d"], # noqa: F722 + cond: float["b n d"], # noqa: F722 + text_embed: float["b n d"], # noqa: F722 + drop_audio_cond=False, + ): + if drop_audio_cond: # cfg for cond audio + cond = torch.zeros_like(cond) + + x = self.proj(torch.cat((x, cond, text_embed), dim=-1)) + x = self.conv_pos_embed(x) + x + return x + + +# Transformer backbone using DiT blocks + + +class DiT(nn.Module): + def __init__( + self, + *, + dim, + depth=8, + heads=8, + dim_head=64, + dropout=0.1, + ff_mult=4, + mel_dim=100, + text_num_embeds=256, + text_dim=None, + conv_layers=0, + long_skip_connection=False, + checkpoint_activations=False, + ): + super().__init__() + + self.time_embed = TimestepEmbedding(dim) + if text_dim is None: + text_dim = mel_dim + self.text_embed = TextEmbedding( + text_num_embeds, text_dim, conv_layers=conv_layers + ) + self.input_embed = InputEmbedding(mel_dim, text_dim, dim) + + self.rotary_embed = RotaryEmbedding(dim_head) + + self.dim = dim + self.depth = depth + + self.transformer_blocks = nn.ModuleList( + [ + DiTBlock( + dim=dim, + heads=heads, + dim_head=dim_head, + ff_mult=ff_mult, + dropout=dropout, + ) + for _ in range(depth) + ] + ) + self.long_skip_connection = ( + nn.Linear(dim * 2, dim, bias=False) if long_skip_connection else None + ) + + self.norm_out = AdaLayerNormZero_Final(dim) # final modulation + self.proj_out = nn.Linear(dim, mel_dim) + + self.checkpoint_activations = checkpoint_activations + + def ckpt_wrapper(self, module): + # https://github.com/chuanyangjin/fast-DiT/blob/main/models.py + def ckpt_forward(*inputs): + outputs = module(*inputs) + return outputs + + return ckpt_forward + + def forward( + self, + x: float["b n d"], # nosied input audio # noqa: F722 + cond: float["b n d"], # masked cond audio # noqa: F722 + text: int["b nt"], # text # noqa: F722 + time: float["b"] | float[""], # time step # noqa: F821 F722 + drop_audio_cond, # cfg for cond audio + drop_text, # cfg for text + mask: bool["b n"] | None = None, # noqa: F722 + ): + batch, seq_len = x.shape[0], x.shape[1] + if time.ndim == 0: + time = time.repeat(batch) + + # t: conditioning time, c: context (text + masked cond audio), x: noised input audio + t = self.time_embed(time) + text_embed = self.text_embed(text, seq_len, drop_text=drop_text) + x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond) + + rope = self.rotary_embed.forward_from_seq_len(seq_len) + + if self.long_skip_connection is not None: + residual = x + + for block in 
self.transformer_blocks: + if self.checkpoint_activations: + x = torch.utils.checkpoint.checkpoint( + self.ckpt_wrapper(block), x, t, mask, rope + ) + else: + x = block(x, t, mask=mask, rope=rope) + + if self.long_skip_connection is not None: + x = self.long_skip_connection(torch.cat((x, residual), dim=-1)) + + x = self.norm_out(x, t) + output = self.proj_out(x) + + return output diff --git a/egs/wenetspeech4tts/TTS/f5-tts/model/modules.py b/egs/wenetspeech4tts/TTS/f5-tts/model/modules.py new file mode 100644 index 0000000000..05299d419e --- /dev/null +++ b/egs/wenetspeech4tts/TTS/f5-tts/model/modules.py @@ -0,0 +1,728 @@ +""" +ein notation: +b - batch +n - sequence +nt - text sequence +nw - raw wave length +d - dimension +""" + +from __future__ import annotations + +import math +from typing import Optional + +import torch +import torch.nn.functional as F +import torchaudio +from librosa.filters import mel as librosa_mel_fn +from torch import nn +from x_transformers.x_transformers import apply_rotary_pos_emb + +# raw wav to mel spec + + +mel_basis_cache = {} +hann_window_cache = {} + + +def get_bigvgan_mel_spectrogram( + waveform, + n_fft=1024, + n_mel_channels=100, + target_sample_rate=24000, + hop_length=256, + win_length=1024, + fmin=0, + fmax=None, + center=False, +): # Copy from https://github.com/NVIDIA/BigVGAN/tree/main + device = waveform.device + key = f"{n_fft}_{n_mel_channels}_{target_sample_rate}_{hop_length}_{win_length}_{fmin}_{fmax}_{device}" + + if key not in mel_basis_cache: + mel = librosa_mel_fn( + sr=target_sample_rate, + n_fft=n_fft, + n_mels=n_mel_channels, + fmin=fmin, + fmax=fmax, + ) + mel_basis_cache[key] = ( + torch.from_numpy(mel).float().to(device) + ) # TODO: why they need .float()? + hann_window_cache[key] = torch.hann_window(win_length).to(device) + + mel_basis = mel_basis_cache[key] + hann_window = hann_window_cache[key] + + padding = (n_fft - hop_length) // 2 + waveform = torch.nn.functional.pad( + waveform.unsqueeze(1), (padding, padding), mode="reflect" + ).squeeze(1) + + spec = torch.stft( + waveform, + n_fft, + hop_length=hop_length, + win_length=win_length, + window=hann_window, + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=True, + ) + spec = torch.sqrt(torch.view_as_real(spec).pow(2).sum(-1) + 1e-9) + + mel_spec = torch.matmul(mel_basis, spec) + mel_spec = torch.log(torch.clamp(mel_spec, min=1e-5)) + + return mel_spec + + +def get_vocos_mel_spectrogram( + waveform, + n_fft=1024, + n_mel_channels=100, + target_sample_rate=24000, + hop_length=256, + win_length=1024, +): + mel_stft = torchaudio.transforms.MelSpectrogram( + sample_rate=target_sample_rate, + n_fft=n_fft, + win_length=win_length, + hop_length=hop_length, + n_mels=n_mel_channels, + power=1, + center=True, + normalized=False, + norm=None, + ).to(waveform.device) + if len(waveform.shape) == 3: + waveform = waveform.squeeze(1) # 'b 1 nw -> b nw' + + assert len(waveform.shape) == 2 + + mel = mel_stft(waveform) + mel = mel.clamp(min=1e-5).log() + return mel + + +class MelSpec(nn.Module): + def __init__( + self, + n_fft=1024, + hop_length=256, + win_length=1024, + n_mel_channels=100, + target_sample_rate=24_000, + mel_spec_type="vocos", + ): + super().__init__() + assert mel_spec_type in ["vocos", "bigvgan"], print( + "We only support two extract mel backend: vocos or bigvgan" + ) + + self.n_fft = n_fft + self.hop_length = hop_length + self.win_length = win_length + self.n_mel_channels = n_mel_channels + self.target_sample_rate = target_sample_rate + 
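+        # select the mel extractor that matches the vocoder used later for waveform synthesis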
+ if mel_spec_type == "vocos": + self.extractor = get_vocos_mel_spectrogram + elif mel_spec_type == "bigvgan": + self.extractor = get_bigvgan_mel_spectrogram + + self.register_buffer("dummy", torch.tensor(0), persistent=False) + + def forward(self, wav): + if self.dummy.device != wav.device: + self.to(wav.device) + + mel = self.extractor( + waveform=wav, + n_fft=self.n_fft, + n_mel_channels=self.n_mel_channels, + target_sample_rate=self.target_sample_rate, + hop_length=self.hop_length, + win_length=self.win_length, + ) + + return mel + + +# sinusoidal position embedding + + +class SinusPositionEmbedding(nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + + def forward(self, x, scale=1000): + device = x.device + half_dim = self.dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb) + emb = scale * x.unsqueeze(1) * emb.unsqueeze(0) + emb = torch.cat((emb.sin(), emb.cos()), dim=-1) + return emb + + +# convolutional position embedding + + +class ConvPositionEmbedding(nn.Module): + def __init__(self, dim, kernel_size=31, groups=16): + super().__init__() + assert kernel_size % 2 != 0 + self.conv1d = nn.Sequential( + nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2), + nn.Mish(), + nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2), + nn.Mish(), + ) + + def forward(self, x: float["b n d"], mask: bool["b n"] | None = None): # noqa: F722 + if mask is not None: + mask = mask[..., None] + x = x.masked_fill(~mask, 0.0) + + x = x.permute(0, 2, 1) + x = self.conv1d(x) + out = x.permute(0, 2, 1) + + if mask is not None: + out = out.masked_fill(~mask, 0.0) + + return out + + +# rotary positional embedding related + + +def precompute_freqs_cis( + dim: int, end: int, theta: float = 10000.0, theta_rescale_factor=1.0 +): + # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning + # has some connection to NTK literature + # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/ + # https://github.com/lucidrains/rotary-embedding-torch/blob/main/rotary_embedding_torch/rotary_embedding_torch.py + theta *= theta_rescale_factor ** (dim / (dim - 2)) + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + t = torch.arange(end, device=freqs.device) # type: ignore + freqs = torch.outer(t, freqs).float() # type: ignore + freqs_cos = torch.cos(freqs) # real part + freqs_sin = torch.sin(freqs) # imaginary part + return torch.cat([freqs_cos, freqs_sin], dim=-1) + + +def get_pos_embed_indices(start, length, max_pos, scale=1.0): + # length = length if isinstance(length, int) else length.max() + scale = scale * torch.ones_like( + start, dtype=torch.float32 + ) # in case scale is a scalar + pos = ( + start.unsqueeze(1) + + ( + torch.arange(length, device=start.device, dtype=torch.float32).unsqueeze(0) + * scale.unsqueeze(1) + ).long() + ) + # avoid extra long error. + pos = torch.where(pos < max_pos, pos, max_pos - 1) + return pos + + +# Global Response Normalization layer (Instance Normalization ?) 
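+# (GRN, from ConvNeXt-V2: the per-channel L2 norm over the sequence dimension is divided by its mean
+# across channels, and the result rescales the input with learned gamma/beta plus a residual connection.)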
+ + +class GRN(nn.Module): + def __init__(self, dim): + super().__init__() + self.gamma = nn.Parameter(torch.zeros(1, 1, dim)) + self.beta = nn.Parameter(torch.zeros(1, 1, dim)) + + def forward(self, x): + Gx = torch.norm(x, p=2, dim=1, keepdim=True) + Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6) + return self.gamma * (x * Nx) + self.beta + x + + +# ConvNeXt-V2 Block https://github.com/facebookresearch/ConvNeXt-V2/blob/main/models/convnextv2.py +# ref: https://github.com/bfs18/e2_tts/blob/main/rfwave/modules.py#L108 + + +class ConvNeXtV2Block(nn.Module): + def __init__( + self, + dim: int, + intermediate_dim: int, + dilation: int = 1, + ): + super().__init__() + padding = (dilation * (7 - 1)) // 2 + self.dwconv = nn.Conv1d( + dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation + ) # depthwise conv + self.norm = nn.LayerNorm(dim, eps=1e-6) + self.pwconv1 = nn.Linear( + dim, intermediate_dim + ) # pointwise/1x1 convs, implemented with linear layers + self.act = nn.GELU() + self.grn = GRN(intermediate_dim) + self.pwconv2 = nn.Linear(intermediate_dim, dim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + residual = x + x = x.transpose(1, 2) # b n d -> b d n + x = self.dwconv(x) + x = x.transpose(1, 2) # b d n -> b n d + x = self.norm(x) + x = self.pwconv1(x) + x = self.act(x) + x = self.grn(x) + x = self.pwconv2(x) + return residual + x + + +# AdaLayerNormZero +# return with modulated x for attn input, and params for later mlp modulation + + +class AdaLayerNormZero(nn.Module): + def __init__(self, dim): + super().__init__() + + self.silu = nn.SiLU() + self.linear = nn.Linear(dim, dim * 6) + + self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + + def forward(self, x, emb=None): + emb = self.linear(self.silu(emb)) + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = torch.chunk( + emb, 6, dim=1 + ) + + x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None] + return x, gate_msa, shift_mlp, scale_mlp, gate_mlp + + +# AdaLayerNormZero for final layer +# return only with modulated x for attn input, cuz no more mlp modulation + + +class AdaLayerNormZero_Final(nn.Module): + def __init__(self, dim): + super().__init__() + + self.silu = nn.SiLU() + self.linear = nn.Linear(dim, dim * 2) + + self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + + def forward(self, x, emb): + emb = self.linear(self.silu(emb)) + scale, shift = torch.chunk(emb, 2, dim=1) + + x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :] + return x + + +# FeedForward + + +class FeedForward(nn.Module): + def __init__( + self, dim, dim_out=None, mult=4, dropout=0.0, approximate: str = "none" + ): + super().__init__() + inner_dim = int(dim * mult) + dim_out = dim_out if dim_out is not None else dim + + activation = nn.GELU(approximate=approximate) + project_in = nn.Sequential(nn.Linear(dim, inner_dim), activation) + self.ff = nn.Sequential( + project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out) + ) + + def forward(self, x): + return self.ff(x) + + +# Attention with possible joint part +# modified from diffusers/src/diffusers/models/attention_processor.py + + +class Attention(nn.Module): + def __init__( + self, + processor: JointAttnProcessor | AttnProcessor, + dim: int, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + context_dim: Optional[int] = None, # if not None -> joint attention + context_pre_only=None, + ): + super().__init__() + + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + 
"Attention equires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." + ) + + self.processor = processor + + self.dim = dim + self.heads = heads + self.inner_dim = dim_head * heads + self.dropout = dropout + + self.context_dim = context_dim + self.context_pre_only = context_pre_only + + self.to_q = nn.Linear(dim, self.inner_dim) + self.to_k = nn.Linear(dim, self.inner_dim) + self.to_v = nn.Linear(dim, self.inner_dim) + + if self.context_dim is not None: + self.to_k_c = nn.Linear(context_dim, self.inner_dim) + self.to_v_c = nn.Linear(context_dim, self.inner_dim) + if self.context_pre_only is not None: + self.to_q_c = nn.Linear(context_dim, self.inner_dim) + + self.to_out = nn.ModuleList([]) + self.to_out.append(nn.Linear(self.inner_dim, dim)) + self.to_out.append(nn.Dropout(dropout)) + + if self.context_pre_only is not None and not self.context_pre_only: + self.to_out_c = nn.Linear(self.inner_dim, dim) + + def forward( + self, + x: float["b n d"], # noised input x # noqa: F722 + c: float["b n d"] = None, # context c # noqa: F722 + mask: bool["b n"] | None = None, # noqa: F722 + rope=None, # rotary position embedding for x + c_rope=None, # rotary position embedding for c + ) -> torch.Tensor: + if c is not None: + return self.processor(self, x, c=c, mask=mask, rope=rope, c_rope=c_rope) + else: + return self.processor(self, x, mask=mask, rope=rope) + + +# Attention processor + + +class AttnProcessor: + def __init__(self): + pass + + def __call__( + self, + attn: Attention, + x: float["b n d"], # noised input x # noqa: F722 + mask: bool["b n"] | None = None, # noqa: F722 + rope=None, # rotary position embedding + ) -> torch.FloatTensor: + batch_size = x.shape[0] + + # `sample` projections. + query = attn.to_q(x) + key = attn.to_k(x) + value = attn.to_v(x) + + # apply rotary position embedding + if rope is not None: + freqs, xpos_scale = rope + q_xpos_scale, k_xpos_scale = ( + (xpos_scale, xpos_scale**-1.0) + if xpos_scale is not None + else (1.0, 1.0) + ) + + query = apply_rotary_pos_emb(query, freqs, q_xpos_scale) + key = apply_rotary_pos_emb(key, freqs, k_xpos_scale) + + # attention + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # mask. e.g. 
inference got a batch with different target durations, mask out the padding + if mask is not None: + attn_mask = mask + attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n' + attn_mask = attn_mask.expand( + batch_size, attn.heads, query.shape[-2], key.shape[-2] + ) + else: + attn_mask = None + + x = F.scaled_dot_product_attention( + query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False + ) + x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + x = x.to(query.dtype) + + # linear proj + x = attn.to_out[0](x) + # dropout + x = attn.to_out[1](x) + + if mask is not None: + mask = mask.unsqueeze(-1) + x = x.masked_fill(~mask, 0.0) + + return x + + +# Joint Attention processor for MM-DiT +# modified from diffusers/src/diffusers/models/attention_processor.py + + +class JointAttnProcessor: + def __init__(self): + pass + + def __call__( + self, + attn: Attention, + x: float["b n d"], # noised input x # noqa: F722 + c: float["b nt d"] = None, # context c, here text # noqa: F722 + mask: bool["b n"] | None = None, # noqa: F722 + rope=None, # rotary position embedding for x + c_rope=None, # rotary position embedding for c + ) -> torch.FloatTensor: + residual = x + + batch_size = c.shape[0] + + # `sample` projections. + query = attn.to_q(x) + key = attn.to_k(x) + value = attn.to_v(x) + + # `context` projections. + c_query = attn.to_q_c(c) + c_key = attn.to_k_c(c) + c_value = attn.to_v_c(c) + + # apply rope for context and noised input independently + if rope is not None: + freqs, xpos_scale = rope + q_xpos_scale, k_xpos_scale = ( + (xpos_scale, xpos_scale**-1.0) + if xpos_scale is not None + else (1.0, 1.0) + ) + query = apply_rotary_pos_emb(query, freqs, q_xpos_scale) + key = apply_rotary_pos_emb(key, freqs, k_xpos_scale) + if c_rope is not None: + freqs, xpos_scale = c_rope + q_xpos_scale, k_xpos_scale = ( + (xpos_scale, xpos_scale**-1.0) + if xpos_scale is not None + else (1.0, 1.0) + ) + c_query = apply_rotary_pos_emb(c_query, freqs, q_xpos_scale) + c_key = apply_rotary_pos_emb(c_key, freqs, k_xpos_scale) + + # attention + query = torch.cat([query, c_query], dim=1) + key = torch.cat([key, c_key], dim=1) + value = torch.cat([value, c_value], dim=1) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # mask. e.g. inference got a batch with different target durations, mask out the padding + if mask is not None: + attn_mask = F.pad(mask, (0, c.shape[1]), value=True) # no mask for c (text) + attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n' + attn_mask = attn_mask.expand( + batch_size, attn.heads, query.shape[-2], key.shape[-2] + ) + else: + attn_mask = None + + x = F.scaled_dot_product_attention( + query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False + ) + x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + x = x.to(query.dtype) + + # Split the attention outputs. + x, c = ( + x[:, : residual.shape[1]], + x[:, residual.shape[1] :], + ) + + # linear proj + x = attn.to_out[0](x) + # dropout + x = attn.to_out[1](x) + if not attn.context_pre_only: + c = attn.to_out_c(c) + + if mask is not None: + mask = mask.unsqueeze(-1) + x = x.masked_fill(~mask, 0.0) + # c = c.masked_fill(~mask, 0.) 
# no mask for c (text) + + return x, c + + +# DiT Block + + +class DiTBlock(nn.Module): + def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1): + super().__init__() + + self.attn_norm = AdaLayerNormZero(dim) + self.attn = Attention( + processor=AttnProcessor(), + dim=dim, + heads=heads, + dim_head=dim_head, + dropout=dropout, + ) + + self.ff_norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + self.ff = FeedForward( + dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh" + ) + + def forward(self, x, t, mask=None, rope=None): # x: noised input, t: time embedding + # pre-norm & modulation for attention input + norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(x, emb=t) + + # attention + attn_output = self.attn(x=norm, mask=mask, rope=rope) + + # process attention output for input x + x = x + gate_msa.unsqueeze(1) * attn_output + + norm = self.ff_norm(x) * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + ff_output = self.ff(norm) + x = x + gate_mlp.unsqueeze(1) * ff_output + + return x + + +# MMDiT Block https://arxiv.org/abs/2403.03206 + + +class MMDiTBlock(nn.Module): + r""" + modified from diffusers/src/diffusers/models/attention.py + + notes. + _c: context related. text, cond, etc. (left part in sd3 fig2.b) + _x: noised input related. (right part) + context_pre_only: last layer only do prenorm + modulation cuz no more ffn + """ + + def __init__( + self, dim, heads, dim_head, ff_mult=4, dropout=0.1, context_pre_only=False + ): + super().__init__() + + self.context_pre_only = context_pre_only + + self.attn_norm_c = ( + AdaLayerNormZero_Final(dim) if context_pre_only else AdaLayerNormZero(dim) + ) + self.attn_norm_x = AdaLayerNormZero(dim) + self.attn = Attention( + processor=JointAttnProcessor(), + dim=dim, + heads=heads, + dim_head=dim_head, + dropout=dropout, + context_dim=dim, + context_pre_only=context_pre_only, + ) + + if not context_pre_only: + self.ff_norm_c = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + self.ff_c = FeedForward( + dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh" + ) + else: + self.ff_norm_c = None + self.ff_c = None + self.ff_norm_x = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + self.ff_x = FeedForward( + dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh" + ) + + def forward( + self, x, c, t, mask=None, rope=None, c_rope=None + ): # x: noised input, c: context, t: time embedding + # pre-norm & modulation for attention input + if self.context_pre_only: + norm_c = self.attn_norm_c(c, t) + else: + norm_c, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.attn_norm_c( + c, emb=t + ) + norm_x, x_gate_msa, x_shift_mlp, x_scale_mlp, x_gate_mlp = self.attn_norm_x( + x, emb=t + ) + + # attention + x_attn_output, c_attn_output = self.attn( + x=norm_x, c=norm_c, mask=mask, rope=rope, c_rope=c_rope + ) + + # process attention output for context c + if self.context_pre_only: + c = None + else: # if not last layer + c = c + c_gate_msa.unsqueeze(1) * c_attn_output + + norm_c = ( + self.ff_norm_c(c) * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] + ) + c_ff_output = self.ff_c(norm_c) + c = c + c_gate_mlp.unsqueeze(1) * c_ff_output + + # process attention output for input x + x = x + x_gate_msa.unsqueeze(1) * x_attn_output + + norm_x = self.ff_norm_x(x) * (1 + x_scale_mlp[:, None]) + x_shift_mlp[:, None] + x_ff_output = self.ff_x(norm_x) + x = x + x_gate_mlp.unsqueeze(1) * x_ff_output + + return c, x + + +# time step conditioning embedding + + +class TimestepEmbedding(nn.Module): + def 
__init__(self, dim, freq_embed_dim=256): + super().__init__() + self.time_embed = SinusPositionEmbedding(freq_embed_dim) + self.time_mlp = nn.Sequential( + nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim) + ) + + def forward(self, timestep: float["b"]): # noqa: F821 + time_hidden = self.time_embed(timestep) + time_hidden = time_hidden.to(timestep.dtype) + time = self.time_mlp(time_hidden) # b d + return time diff --git a/egs/wenetspeech4tts/TTS/f5-tts/model/utils.py b/egs/wenetspeech4tts/TTS/f5-tts/model/utils.py new file mode 100644 index 0000000000..fae5fadb61 --- /dev/null +++ b/egs/wenetspeech4tts/TTS/f5-tts/model/utils.py @@ -0,0 +1,206 @@ +from __future__ import annotations + +import os +import random +from collections import defaultdict +from importlib.resources import files + +import jieba +import torch +from pypinyin import Style, lazy_pinyin +from torch.nn.utils.rnn import pad_sequence + +# seed everything + + +def seed_everything(seed=0): + random.seed(seed) + os.environ["PYTHONHASHSEED"] = str(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + +# helpers + + +def exists(v): + return v is not None + + +def default(v, d): + return v if exists(v) else d + + +# tensor helpers + + +def lens_to_mask( + t: int["b"], length: int | None = None # noqa: F722 F821 +) -> bool["b n"]: # noqa: F722 F821 + if not exists(length): + length = t.amax() + + seq = torch.arange(length, device=t.device) + return seq[None, :] < t[:, None] + + +def mask_from_start_end_indices( + seq_len: int["b"], start: int["b"], end: int["b"] # noqa: F722 F821 +): + max_seq_len = seq_len.max().item() + seq = torch.arange(max_seq_len, device=start.device).long() + start_mask = seq[None, :] >= start[:, None] + end_mask = seq[None, :] < end[:, None] + return start_mask & end_mask + + +def mask_from_frac_lengths( + seq_len: int["b"], frac_lengths: float["b"] # noqa: F722 F821 +): + lengths = (frac_lengths * seq_len).long() + max_start = seq_len - lengths + + rand = torch.rand_like(frac_lengths) + start = (max_start * rand).long().clamp(min=0) + end = start + lengths + + return mask_from_start_end_indices(seq_len, start, end) + + +def maybe_masked_mean( + t: float["b n d"], mask: bool["b n"] = None # noqa: F722 F821 +) -> float["b d"]: # noqa: F722 F821 + if not exists(mask): + return t.mean(dim=1) + + t = torch.where(mask[:, :, None], t, torch.tensor(0.0, device=t.device)) + num = t.sum(dim=1) + den = mask.float().sum(dim=1) + + return num / den.clamp(min=1.0) + + +# simple utf-8 tokenizer, since paper went character based +def list_str_to_tensor(text: list[str], padding_value=-1) -> int["b nt"]: # noqa: F722 + list_tensors = [torch.tensor([*bytes(t, "UTF-8")]) for t in text] # ByT5 style + text = pad_sequence(list_tensors, padding_value=padding_value, batch_first=True) + return text + + +# char tokenizer, based on custom dataset's extracted .txt file +def list_str_to_idx( + text: list[str] | list[list[str]], + vocab_char_map: dict[str, int], # {char: idx} + padding_value=-1, +) -> int["b nt"]: # noqa: F722 + list_idx_tensors = [ + torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text + ] # pinyin or char style + text = pad_sequence(list_idx_tensors, padding_value=padding_value, batch_first=True) + return text + + +# Get tokenizer + + +def get_tokenizer(dataset_name, tokenizer: str = "pinyin"): + """ + tokenizer - "pinyin" do g2p for only chinese characters, need 
.txt vocab_file + - "char" for char-wise tokenizer, need .txt vocab_file + - "byte" for utf-8 tokenizer + - "custom" if you're directly passing in a path to the vocab.txt you want to use + vocab_size - if use "pinyin", all available pinyin types, common alphabets (also those with accent) and symbols + - if use "char", derived from unfiltered character & symbol counts of custom dataset + - if use "byte", set to 256 (unicode byte range) + """ + if tokenizer in ["pinyin", "char"]: + tokenizer_path = os.path.join( + files("f5_tts").joinpath("../../data"), + f"{dataset_name}_{tokenizer}/vocab.txt", + ) + with open(tokenizer_path, "r", encoding="utf-8") as f: + vocab_char_map = {} + for i, char in enumerate(f): + vocab_char_map[char[:-1]] = i + vocab_size = len(vocab_char_map) + assert ( + vocab_char_map[" "] == 0 + ), "make sure space is of idx 0 in vocab.txt, cuz 0 is used for unknown char" + + elif tokenizer == "byte": + vocab_char_map = None + vocab_size = 256 + + elif tokenizer == "custom": + with open(dataset_name, "r", encoding="utf-8") as f: + vocab_char_map = {} + for i, char in enumerate(f): + vocab_char_map[char[:-1]] = i + vocab_size = len(vocab_char_map) + + return vocab_char_map, vocab_size + + +# convert char to pinyin + +jieba.initialize() +print("Word segmentation module jieba initialized.\n") + + +def convert_char_to_pinyin(text_list, polyphone=True): + final_text_list = [] + custom_trans = str.maketrans( + {";": ",", "“": '"', "”": '"', "‘": "'", "’": "'"} + ) # add custom trans here, to address oov + + def is_chinese(c): + return "\u3100" <= c <= "\u9fff" # common chinese characters + + for text in text_list: + char_list = [] + text = text.translate(custom_trans) + for seg in jieba.cut(text): + seg_byte_len = len(bytes(seg, "UTF-8")) + if seg_byte_len == len(seg): # if pure alphabets and symbols + if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"": + char_list.append(" ") + char_list.extend(seg) + elif polyphone and seg_byte_len == 3 * len( + seg + ): # if pure east asian characters + seg_ = lazy_pinyin(seg, style=Style.TONE3, tone_sandhi=True) + for i, c in enumerate(seg): + if is_chinese(c): + char_list.append(" ") + char_list.append(seg_[i]) + else: # if mixed characters, alphabets and symbols + for c in seg: + if ord(c) < 256: + char_list.extend(c) + elif is_chinese(c): + char_list.append(" ") + char_list.extend( + lazy_pinyin(c, style=Style.TONE3, tone_sandhi=True) + ) + else: + char_list.append(c) + final_text_list.append(char_list) + + return final_text_list + + +# filter func for dirty data with many repetitions + + +def repetition_found(text, length=2, tolerance=10): + pattern_count = defaultdict(int) + for i in range(len(text) - length + 1): + pattern = text[i : i + length] + pattern_count[pattern] += 1 + for pattern, count in pattern_count.items(): + if count > tolerance: + return True + return False diff --git a/egs/wenetspeech4tts/TTS/f5-tts/speech_synthesis.py b/egs/wenetspeech4tts/TTS/f5-tts/speech_synthesis.py new file mode 100644 index 0000000000..57f677fcb3 --- /dev/null +++ b/egs/wenetspeech4tts/TTS/f5-tts/speech_synthesis.py @@ -0,0 +1,104 @@ +from typing import Callable, Dict, List, Sequence, Union + +import torch +from lhotse import validate +from lhotse.cut import CutSet +from lhotse.dataset.collation import collate_audio +from lhotse.dataset.input_strategies import BatchIO, PrecomputedFeatures +from lhotse.utils import ifnone + + +class SpeechSynthesisDataset(torch.utils.data.Dataset): + """ + The PyTorch Dataset for the speech 
synthesis task. + Each item in this dataset is a dict of: + + .. code-block:: + + { + 'audio': (B x NumSamples) float tensor + 'features': (B x NumFrames x NumFeatures) float tensor + 'audio_lens': (B, ) int tensor + 'features_lens': (B, ) int tensor + 'text': List[str] of len B # when return_text=True + 'tokens': List[List[str]] # when return_tokens=True + 'speakers': List[str] of len B # when return_spk_ids=True + 'cut': List of Cuts # when return_cuts=True + } + """ + + def __init__( + self, + cut_transforms: List[Callable[[CutSet], CutSet]] = None, + feature_input_strategy: BatchIO = PrecomputedFeatures(), + feature_transforms: Union[Sequence[Callable], Callable] = None, + return_text: bool = True, + return_tokens: bool = False, + return_spk_ids: bool = False, + return_cuts: bool = False, + ) -> None: + super().__init__() + + self.cut_transforms = ifnone(cut_transforms, []) + self.feature_input_strategy = feature_input_strategy + + self.return_text = return_text + self.return_tokens = return_tokens + self.return_spk_ids = return_spk_ids + self.return_cuts = return_cuts + + if feature_transforms is None: + feature_transforms = [] + elif not isinstance(feature_transforms, Sequence): + feature_transforms = [feature_transforms] + + assert all( + isinstance(transform, Callable) for transform in feature_transforms + ), "Feature transforms must be Callable" + self.feature_transforms = feature_transforms + + def __getitem__(self, cuts: CutSet) -> Dict[str, torch.Tensor]: + validate_for_tts(cuts) + + for transform in self.cut_transforms: + cuts = transform(cuts) + + # audio, audio_lens = collate_audio(cuts) + features, features_lens = self.feature_input_strategy(cuts) + + for transform in self.feature_transforms: + features = transform(features) + + batch = { + # "audio": audio, + "features": features, + # "audio_lens": audio_lens, + "features_lens": features_lens, + } + + if self.return_text: + # use normalized text + # text = [cut.supervisions[0].normalized_text for cut in cuts] + text = [cut.supervisions[0].text for cut in cuts] + batch["text"] = text + + if self.return_tokens: + # tokens = [cut.tokens for cut in cuts] + tokens = [cut.supervisions[0].custom["tokens"]["text"] for cut in cuts] + batch["tokens"] = tokens + + if self.return_spk_ids: + batch["speakers"] = [cut.supervisions[0].speaker for cut in cuts] + + if self.return_cuts: + batch["cut"] = [cut for cut in cuts] + + return batch + + +def validate_for_tts(cuts: CutSet) -> None: + validate(cuts) + for cut in cuts: + assert ( + len(cut.supervisions) == 1 + ), "Only the Cuts with single supervision are supported." diff --git a/egs/wenetspeech4tts/TTS/f5-tts/train.py b/egs/wenetspeech4tts/TTS/f5-tts/train.py new file mode 100755 index 0000000000..37dcf531e5 --- /dev/null +++ b/egs/wenetspeech4tts/TTS/f5-tts/train.py @@ -0,0 +1,1178 @@ +#!/usr/bin/env python3 +# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo) +# Copyright 2023 (authors: Feiteng Li) +# Copyright 2024 (authors: Yuekai Zhang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +# docker: ghcr.io/swivid/f5-tts:main +# pip install k2==1.24.4.dev20241030+cuda12.4.torch2.4.0 -f https://k2-fsa.github.io/k2/cuda.html +# pip install kaldialign lhotse tensorboard bigvganinference sentencepiece + +world_size=8 +exp_dir=exp/f5-tts-small +python3 f5-tts/train.py --max-duration 700 --filter-min-duration 0.5 --filter-max-duration 20 \ + --num-buckets 6 --dtype "bfloat16" --save-every-n 5000 --valid-interval 10000 \ + --base-lr 7.5e-5 --warmup-steps 20000 --num-epochs 60 \ + --num-decoder-layers 18 --nhead 12 --decoder-dim 768 \ + --exp-dir ${exp_dir} --world-size ${world_size} +""" + +import argparse +import copy +import logging +import os +import random +import warnings +from contextlib import nullcontext +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from lhotse import CutSet +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model.cfm import CFM +from model.dit import DiT +from model.utils import convert_char_to_pinyin +from torch import Tensor +from torch.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.optim.lr_scheduler import LinearLR, SequentialLR +from torch.utils.tensorboard import SummaryWriter +from tts_datamodule import TtsDataModule +from utils import MetricsTracker + +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import AttributeDict, setup_logger, str2bool # MetricsTracker + +LRSchedulerType = torch.optim.lr_scheduler._LRScheduler + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + + for module in model.modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--decoder-dim", + type=int, + default=1024, + help="Embedding dimension in the decoder model.", + ) + + parser.add_argument( + "--nhead", + type=int, + default=16, + help="Number of attention heads in the Decoder layers.", + ) + + parser.add_argument( + "--num-decoder-layers", + type=int, + default=22, + help="Number of Decoder layers.", + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + 
type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=20, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=Path, + default="exp/f5", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--tokens", + type=str, + default="f5-tts/vocab.txt", + help="Path to the unique text tokens file", + ) + + parser.add_argument( + "--pretrained-model-path", + type=str, + default=None, + help="Path to file", + ) + + parser.add_argument( + "--optimizer-name", + type=str, + default="AdamW", + help="The optimizer.", + ) + parser.add_argument( + "--base-lr", type=float, default=0.05, help="The base learning rate." + ) + parser.add_argument( + "--warmup-steps", + type=int, + default=200, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--decay-steps", + type=int, + default=1000000, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=10000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train %% save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 0. + """, + ) + parser.add_argument( + "--valid-interval", + type=int, + default=10000, + help="""Run validation if batch_idx %% valid_interval is 0.""", + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=20, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=0, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. 
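+          Set to 0 (the default here) to disable model averaging.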
+ """, + ) + + parser.add_argument( + "--accumulate-grad-steps", + type=int, + default=1, + help="""update gradient when batch_idx_train %% accumulate_grad_steps == 0. + """, + ) + + parser.add_argument( + "--dtype", + type=str, + default="bfloat16", + help="Training dtype: float32 bfloat16 float16.", + ) + + parser.add_argument( + "--filter-min-duration", + type=float, + default=0.0, + help="Keep only utterances with duration > this.", + ) + + parser.add_argument( + "--filter-max-duration", + type=float, + default=20.0, + help="Keep only utterances with duration < this.", + ) + + parser.add_argument( + "--oom-check", + type=str2bool, + default=False, + help="perform OOM check on dataloader batches before starting training.", + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 100, + "reset_interval": 200, + "valid_interval": 10000, + "env_info": get_env_info(), + } + ) + + return params + + +def get_tokenizer(vocab_file_path: str): + """ + tokenizer - "pinyin" do g2p for only chinese characters, need .txt vocab_file + - "char" for char-wise tokenizer, need .txt vocab_file + - "byte" for utf-8 tokenizer + - "custom" if you're directly passing in a path to the vocab.txt you want to use + vocab_size - if use "pinyin", all available pinyin types, common alphabets (also those with accent) and symbols + - if use "char", derived from unfiltered character & symbol counts of custom dataset + - if use "byte", set to 256 (unicode byte range) + """ + with open(vocab_file_path, "r", encoding="utf-8") as f: + vocab_char_map = {} + for i, char in enumerate(f): + vocab_char_map[char[:-1]] = i + vocab_size = len(vocab_char_map) + + return vocab_char_map, vocab_size + + +def get_model(params): + vocab_char_map, vocab_size = get_tokenizer(params.tokens) + # bigvgan 100 dim features + n_mel_channels = 100 + n_fft = 1024 + sampling_rate = 24_000 + hop_length = 256 + win_length = 1024 + + model_cfg = { + "dim": params.decoder_dim, + "depth": params.num_decoder_layers, + "heads": params.nhead, + "ff_mult": 2, + "text_dim": 512, + "conv_layers": 4, + "checkpoint_activations": False, + } + model = CFM( + transformer=DiT( + **model_cfg, 
text_num_embeds=vocab_size, mel_dim=n_mel_channels + ), + mel_spec_kwargs=dict( + n_fft=n_fft, + hop_length=hop_length, + win_length=win_length, + n_mel_channels=n_mel_channels, + target_sample_rate=sampling_rate, + mel_spec_type="bigvgan", + ), + odeint_kwargs=dict( + method="euler", + ), + vocab_char_map=vocab_char_map, + ) + return model + + +def load_F5_TTS_pretrained_checkpoint( + model, ckpt_path, device: str = "cpu", dtype=torch.float32 +): + checkpoint = torch.load(ckpt_path, map_location=device, weights_only=True) + if "ema_model_state_dict" in checkpoint: + checkpoint["model_state_dict"] = { + k.replace("ema_model.", ""): v + for k, v in checkpoint["ema_model_state_dict"].items() + if k not in ["initted", "step"] + } + + # patch for backward compatibility, 305e3ea + for key in [ + "mel_spec.mel_stft.mel_scale.fb", + "mel_spec.mel_stft.spectrogram.window", + ]: + if key in checkpoint["model_state_dict"]: + del checkpoint["model_state_dict"][key] + model.load_state_dict(checkpoint["model_state_dict"]) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + if isinstance(model, DDP): + raise ValueError("load_checkpoint before DDP") + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. 
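+      scheduler:
+        The learning rate scheduler used in the training.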
+ scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def prepare_input(batch: dict, device: torch.device): + """Parse batch data""" + text_inputs = batch["text"] + # texts.extend(convert_char_to_pinyin([text], polyphone=true)) + text_inputs = convert_char_to_pinyin(text_inputs, polyphone=True) + + mel_spec = batch["features"] + mel_lengths = batch["features_lens"] + return text_inputs, mel_spec.to(device), mel_lengths.to(device) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + tokenizer, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute transducer loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. 
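+    Returns:
+      Return a tuple of two elements: the loss tensor and a MetricsTracker
+      that holds the number of samples and the summed loss for logging.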
+ """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + (text_inputs, mel_spec, mel_lengths) = prepare_input(batch, device=device) + # at entry, TextTokens is (N, P) + + with torch.set_grad_enabled(is_training): + loss, cond, pred = model(mel_spec, text=text_inputs, lens=mel_lengths) + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["samples"] = mel_lengths.size(0) + + info["loss"] = loss.detach().cpu().item() * info["samples"] + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + tokenizer, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + tokenizer=tokenizer, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + if world_size > 1: + tot_loss.reduce(loss.device) + loss_value = tot_loss["loss"] / tot_loss["samples"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + tokenizer, + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + rng: random.Random, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + rng: + Random for selecting. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. 
+ """ + model.train() + tot_loss = MetricsTracker() + iter_dl = iter(train_dl) + + dtype, enabled = torch.float32, False + if params.dtype in ["bfloat16", "bf16"]: + dtype, enabled = torch.bfloat16, True + elif params.dtype in ["float16", "fp16"]: + dtype, enabled = torch.float16, True + + batch_idx = 0 + while True: + try: + batch = next(iter_dl) + except StopIteration: + logging.info("Reaches end of dataloader.") + break + + batch_idx += 1 + + params.batch_idx_train += 1 + batch_size = len(batch["text"]) + + try: + with torch.amp.autocast("cuda", dtype=dtype, enabled=enabled): + loss, loss_info = compute_loss( + params=params, + model=model, + tokenizer=tokenizer, + batch=batch, + is_training=True, + ) + + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info * ( + 1 / params.reset_interval + ) + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + + scaler.scale(loss).backward() + if params.batch_idx_train >= params.accumulate_grad_steps: + if params.batch_idx_train % params.accumulate_grad_steps == 0: + + # Unscales the gradients of optimizer's assigned params in-place + scaler.unscale_(optimizer) + # Since the gradients of optimizer's assigned params are unscaled, clips as usual: + torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + # loss.backward() + # optimizer.step() + + for k in range(params.accumulate_grad_steps): + scheduler.step() + + set_batch_count(model, params.batch_idx_train) + except: # noqa + display_and_save_batch(batch, params=params) + raise + + if params.average_period > 0: + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + # Perform Operation in rank 0 + if rank == 0: + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + # Perform Operation in rank 0 + if rank == 0: + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.dtype in ["float16", "fp16"]: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. 
+ cur_grad_scale = scaler._scale.item() + if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + + if cur_grad_scale < 0.01: + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = scheduler.get_last_lr()[0] + cur_grad_scale = ( + scaler._scale.item() if params.dtype in ["float16", "fp16"] else 1.0 + ) + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, train_loss[{loss_info}], " + f"batch size: {batch_size}, " + f"lr: {cur_lr:.2e}" + + ( + f", grad_scale: {cur_grad_scale}" + if params.dtype in ["float16", "fp16"] + else "" + ) + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + loss_info.write_summary( + tb_writer, + "train/current_", + params.batch_idx_train, + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.dtype in ["float16", "fp16"]: + tb_writer.add_scalar( + "train/grad_scale", + cur_grad_scale, + params.batch_idx_train, + ) + + if params.batch_idx_train % params.valid_interval == 0: + # Calculate validation loss in Rank 0 + model.eval() + logging.info("Computing validation loss") + with torch.amp.autocast("cuda", dtype=dtype): + valid_info = compute_validation_loss( + params=params, + model=model, + tokenizer=tokenizer, + valid_dl=valid_dl, + world_size=world_size, + ) + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + model.train() + + loss_value = tot_loss["loss"] / tot_loss["samples"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def filter_short_and_long_utterances( + cuts: CutSet, min_duration: float, max_duration: float +) -> CutSet: + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 0.6 second and 20 seconds + if c.duration < min_duration or c.duration > max_duration: + # logging.warning( + # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + # ) + return False + return True + + cuts = cuts.filter(remove_short_and_long_utt) + + return cuts + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. 
+ args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + rng = random.Random(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + # https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + torch.backends.cudnn.allow_tf32 = True + torch.backends.cuda.matmul.allow_tf32 = True + + logging.info(f"Device: {device}") + tokenizer = get_tokenizer(params.tokens) + logging.info(params) + + logging.info("About to create model") + + model = get_model(params) + + if params.pretrained_model_path: + checkpoint = torch.load(params.pretrained_model_path, map_location="cpu") + if "ema_model_state_dict" in checkpoint or "model_state_dict" in checkpoint: + model = load_F5_TTS_pretrained_checkpoint( + model, params.pretrained_model_path + ) + else: + _ = load_checkpoint( + params.pretrained_model_path, + model=model, + ) + + model = model.to(device) + + with open(f"{params.exp_dir}/model.txt", "w") as f: + print(model) + print(model, file=f) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0 and params.average_period > 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=False) + + model_parameters = model.parameters() + + optimizer = torch.optim.AdamW( + model_parameters, + lr=params.base_lr, + betas=(0.9, 0.95), + weight_decay=1e-2, + eps=1e-8, + ) + + warmup_scheduler = LinearLR( + optimizer, start_factor=1e-8, end_factor=1.0, total_iters=params.warmup_steps + ) + decay_scheduler = LinearLR( + optimizer, start_factor=1.0, end_factor=1e-8, total_iters=params.decay_steps + ) + scheduler = SequentialLR( + optimizer, + schedulers=[warmup_scheduler, decay_scheduler], + milestones=[params.warmup_steps], + ) + + optimizer.zero_grad() + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.inf_check: + register_inf_check_hooks(model) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + dataset = TtsDataModule(args) + train_cuts = dataset.train_cuts() + valid_cuts = dataset.valid_cuts() + + train_cuts = filter_short_and_long_utterances( + train_cuts, params.filter_min_duration, params.filter_max_duration + ) + valid_cuts = filter_short_and_long_utterances( + valid_cuts, params.filter_min_duration, params.filter_max_duration + 
) + + train_dl = dataset.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + valid_dl = dataset.valid_dataloaders(valid_cuts) + + if params.oom_check: + scan_pessimistic_batches_for_oom( + model=model, + tokenizer=tokenizer, + train_dl=train_dl, + optimizer=optimizer, + params=params, + ) + + scaler = GradScaler( + "cuda", enabled=(params.dtype in ["fp16", "float16"]), init_scale=1.0 + ) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + tokenizer=tokenizer, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + train_dl=train_dl, + valid_dl=valid_dl, + rng=rng, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + tokenizer, + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + dtype = torch.float32 + if params.dtype in ["bfloat16", "bf16"]: + dtype = torch.bfloat16 + elif params.dtype in ["float16", "fp16"]: + dtype = torch.float16 + + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + print(batch.keys()) + try: + with torch.amp.autocast("cuda", dtype=dtype): + loss, loss_info = compute_loss( + params=params, + model=model, + tokenizer=tokenizer, + batch=batch, + is_training=True, + ) + loss.backward(retain_graph=True) + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." 
+ ) + display_and_save_batch(batch, params=params) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + TtsDataModule.add_arguments(parser) + args = parser.parse_args() + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/wenetspeech4tts/TTS/f5-tts/tts_datamodule.py b/egs/wenetspeech4tts/TTS/f5-tts/tts_datamodule.py new file mode 100644 index 0000000000..80ba173183 --- /dev/null +++ b/egs/wenetspeech4tts/TTS/f5-tts/tts_datamodule.py @@ -0,0 +1,306 @@ +# Copyright 2021 Piotr Żelasko +# Copyright 2022-2023 Xiaomi Corporation (Authors: Mingshuang Luo, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import logging +from functools import lru_cache +from pathlib import Path +from typing import Any, Dict, Optional + +import torch +from lhotse import CutSet, load_manifest_lazy +from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures; SpeechSynthesisDataset, + CutConcatenate, + CutMix, + DynamicBucketingSampler, + PrecomputedFeatures, + SimpleCutSampler, +) +from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples + AudioSamples, + OnTheFlyFeatures, +) +from lhotse.utils import fix_random_seed +from speech_synthesis import SpeechSynthesisDataset # noqa F401 +from torch.utils.data import DataLoader + +from icefall.utils import str2bool + + +class _SeedWorkers: + def __init__(self, seed: int): + self.seed = seed + + def __call__(self, worker_id: int): + fix_random_seed(self.seed + worker_id) + + +class TtsDataModule: + """ + DataModule for tts experiments. + It assumes there is always one train and valid dataloader, + but there can be multiple test dataloaders (e.g. LibriSpeech test-clean + and test-other). + + It contains all the common data pipeline modules used in ASR + experiments, e.g.: + - dynamic batch size, + - bucketing samplers, + - cut concatenation, + - on-the-fly feature extraction + + This class should be derived for specific corpora used in ASR tasks. 
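+
+    Typical usage, mirroring train.py (a sketch, not a complete example):
+
+    .. code-block::
+
+        parser = argparse.ArgumentParser()
+        TtsDataModule.add_arguments(parser)
+        args = parser.parse_args()
+        dataset = TtsDataModule(args)
+        train_dl = dataset.train_dataloaders(dataset.train_cuts())
+        valid_dl = dataset.valid_dataloaders(dataset.valid_cuts())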
+ """ + + def __init__(self, args: argparse.Namespace): + self.args = args + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="TTS data related options", + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", + ) + + group.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/fbank"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--max-duration", + type=int, + default=200.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=True, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + group.add_argument( + "--num-buckets", + type=int, + default=30, + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", + ) + + group.add_argument( + "--on-the-fly-feats", + type=str2bool, + default=False, + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--drop-last", + type=str2bool, + default=True, + help="Whether to drop last batch. Used by sampler.", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=False, + help="When enabled, each batch will have the " + "field: batch['cut'] with the cuts that " + "were used to construct it.", + ) + group.add_argument( + "--num-workers", + type=int, + default=2, + help="The number of training dataloader workers that " + "collect the batches.", + ) + + group.add_argument( + "--input-strategy", + type=str, + default="PrecomputedFeatures", + help="AudioSamples or PrecomputedFeatures", + ) + parser.add_argument( + "--prefix", + type=str, + default="wenetspeech4tts", + help="prefix of the manifest file", + ) + + def train_dataloaders( + self, + cuts_train: CutSet, + sampler_state_dict: Optional[Dict[str, Any]] = None, + ) -> DataLoader: + """ + Args: + cuts_train: + CutSet for training. + sampler_state_dict: + The state dict for the training sampler. + """ + logging.info("About to create train dataset") + train = SpeechSynthesisDataset( + return_text=True, + return_tokens=False, + feature_input_strategy=eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + + if self.args.on_the_fly_feats: + raise NotImplementedError( + "On-the-fly feature extraction is not implemented yet." 
+ ) + + if self.args.bucketing_sampler: + logging.info("Using DynamicBucketingSampler.") + train_sampler = DynamicBucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, + drop_last=self.args.drop_last, + ) + else: + logging.info("Using SimpleCutSampler.") + train_sampler = SimpleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + ) + logging.info("About to create train dataloader") + + if sampler_state_dict is not None: + logging.info("Loading sampler state dict") + train_sampler.load_state_dict(sampler_state_dict) + + # 'seed' is derived from the current random state, which will have + # previously been set in the main process. + seed = torch.randint(0, 100000, ()).item() + worker_init_fn = _SeedWorkers(seed) + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=True, + pin_memory=True, + worker_init_fn=worker_init_fn, + ) + + return train_dl + + def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: + logging.info("About to create dev dataset") + if self.args.on_the_fly_feats: + raise NotImplementedError( + "On-the-fly feature extraction is not implemented yet." + ) + else: + validate = SpeechSynthesisDataset( + return_text=True, + return_tokens=False, + feature_input_strategy=eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + valid_sampler = DynamicBucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + num_buckets=self.args.num_buckets, + shuffle=False, + ) + logging.info("About to create valid dataloader") + valid_dl = DataLoader( + validate, + sampler=valid_sampler, + batch_size=None, + num_workers=2, + persistent_workers=True, + pin_memory=True, + ) + + return valid_dl + + def test_dataloaders(self, cuts: CutSet) -> DataLoader: + logging.info("About to create test dataset") + if self.args.on_the_fly_feats: + raise NotImplementedError( + "On-the-fly feature extraction is not implemented yet." 
+ ) + else: + test = SpeechSynthesisDataset( + return_text=True, + return_tokens=False, + feature_input_strategy=eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + test_sampler = DynamicBucketingSampler( + cuts, + max_duration=self.args.max_duration, + num_buckets=self.args.num_buckets, + shuffle=False, + ) + logging.info("About to create test dataloader") + test_dl = DataLoader( + test, + batch_size=None, + sampler=test_sampler, + num_workers=self.args.num_workers, + ) + return test_dl + + @lru_cache() + def train_cuts(self) -> CutSet: + logging.info("About to get train cuts") + return load_manifest_lazy( + self.args.manifest_dir / f"{self.args.prefix}_cuts_train.jsonl.gz" + ) + + @lru_cache() + def valid_cuts(self) -> CutSet: + logging.info("About to get validation cuts") + return load_manifest_lazy( + self.args.manifest_dir / f"{self.args.prefix}_cuts_valid.jsonl.gz" + ) + + @lru_cache() + def test_cuts(self) -> CutSet: + logging.info("About to get test cuts") + return load_manifest_lazy( + self.args.manifest_dir / f"{self.args.prefix}_cuts_test.jsonl.gz" + ) diff --git a/egs/wenetspeech4tts/TTS/f5-tts/utils.py b/egs/wenetspeech4tts/TTS/f5-tts/utils.py new file mode 120000 index 0000000000..ceaaea1963 --- /dev/null +++ b/egs/wenetspeech4tts/TTS/f5-tts/utils.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/matcha/utils.py \ No newline at end of file diff --git a/egs/wenetspeech4tts/TTS/f5-tts/vocab.txt b/egs/wenetspeech4tts/TTS/f5-tts/vocab.txt new file mode 100644 index 0000000000..93f8b48b2c --- /dev/null +++ b/egs/wenetspeech4tts/TTS/f5-tts/vocab.txt @@ -0,0 +1,2545 @@ + +! +" +# +$ +% +& +' +( +) +* ++ +, +- +. +/ +0 +1 +2 +3 +4 +5 +6 +7 +8 +9 +: +; += +> +? +@ +A +B +C +D +E +F +G +H +I +J +K +L +M +N +O +P +Q +R +S +T +U +V +W +X +Y +Z +[ +\ +] +_ +a +a1 +ai1 +ai2 +ai3 +ai4 +an1 +an3 +an4 +ang1 +ang2 +ang4 +ao1 +ao2 +ao3 +ao4 +b +ba +ba1 +ba2 +ba3 +ba4 +bai1 +bai2 +bai3 +bai4 +ban1 +ban2 +ban3 +ban4 +bang1 +bang2 +bang3 +bang4 +bao1 +bao2 +bao3 +bao4 +bei +bei1 +bei2 +bei3 +bei4 +ben1 +ben2 +ben3 +ben4 +beng +beng1 +beng2 +beng3 +beng4 +bi1 +bi2 +bi3 +bi4 +bian1 +bian2 +bian3 +bian4 +biao1 +biao2 +biao3 +bie1 +bie2 +bie3 +bie4 +bin1 +bin4 +bing1 +bing2 +bing3 +bing4 +bo +bo1 +bo2 +bo3 +bo4 +bu2 +bu3 +bu4 +c +ca1 +cai1 +cai2 +cai3 +cai4 +can1 +can2 +can3 +can4 +cang1 +cang2 +cao1 +cao2 +cao3 +ce4 +cen1 +cen2 +ceng1 +ceng2 +ceng4 +cha1 +cha2 +cha3 +cha4 +chai1 +chai2 +chan1 +chan2 +chan3 +chan4 +chang1 +chang2 +chang3 +chang4 +chao1 +chao2 +chao3 +che1 +che2 +che3 +che4 +chen1 +chen2 +chen3 +chen4 +cheng1 +cheng2 +cheng3 +cheng4 +chi1 +chi2 +chi3 +chi4 +chong1 +chong2 +chong3 +chong4 +chou1 +chou2 +chou3 +chou4 +chu1 +chu2 +chu3 +chu4 +chua1 +chuai1 +chuai2 +chuai3 +chuai4 +chuan1 +chuan2 +chuan3 +chuan4 +chuang1 +chuang2 +chuang3 +chuang4 +chui1 +chui2 +chun1 +chun2 +chun3 +chuo1 +chuo4 +ci1 +ci2 +ci3 +ci4 +cong1 +cong2 +cou4 +cu1 +cu4 +cuan1 +cuan2 +cuan4 +cui1 +cui3 +cui4 +cun1 +cun2 +cun4 +cuo1 +cuo2 +cuo4 +d +da +da1 +da2 +da3 +da4 +dai1 +dai2 +dai3 +dai4 +dan1 +dan2 +dan3 +dan4 +dang1 +dang2 +dang3 +dang4 +dao1 +dao2 +dao3 +dao4 +de +de1 +de2 +dei3 +den4 +deng1 +deng2 +deng3 +deng4 +di1 +di2 +di3 +di4 +dia3 +dian1 +dian2 +dian3 +dian4 +diao1 +diao3 +diao4 +die1 +die2 +die4 +ding1 +ding2 +ding3 +ding4 +diu1 +dong1 +dong3 +dong4 +dou1 +dou2 +dou3 +dou4 +du1 +du2 +du3 +du4 +duan1 +duan2 +duan3 +duan4 +dui1 +dui4 +dun1 +dun3 +dun4 +duo1 +duo2 +duo3 +duo4 +e +e1 +e2 +e3 +e4 +ei2 +en1 +en4 +er +er2 +er3 +er4 +f +fa1 +fa2 +fa3 +fa4 +fan1 +fan2 +fan3 +fan4 +fang1 
+fang2 +fang3 +fang4 +fei1 +fei2 +fei3 +fei4 +fen1 +fen2 +fen3 +fen4 +feng1 +feng2 +feng3 +feng4 +fo2 +fou2 +fou3 +fu1 +fu2 +fu3 +fu4 +g +ga1 +ga2 +ga3 +ga4 +gai1 +gai2 +gai3 +gai4 +gan1 +gan2 +gan3 +gan4 +gang1 +gang2 +gang3 +gang4 +gao1 +gao2 +gao3 +gao4 +ge1 +ge2 +ge3 +ge4 +gei2 +gei3 +gen1 +gen2 +gen3 +gen4 +geng1 +geng3 +geng4 +gong1 +gong3 +gong4 +gou1 +gou2 +gou3 +gou4 +gu +gu1 +gu2 +gu3 +gu4 +gua1 +gua2 +gua3 +gua4 +guai1 +guai2 +guai3 +guai4 +guan1 +guan2 +guan3 +guan4 +guang1 +guang2 +guang3 +guang4 +gui1 +gui2 +gui3 +gui4 +gun3 +gun4 +guo1 +guo2 +guo3 +guo4 +h +ha1 +ha2 +ha3 +hai1 +hai2 +hai3 +hai4 +han1 +han2 +han3 +han4 +hang1 +hang2 +hang4 +hao1 +hao2 +hao3 +hao4 +he1 +he2 +he4 +hei1 +hen2 +hen3 +hen4 +heng1 +heng2 +heng4 +hong1 +hong2 +hong3 +hong4 +hou1 +hou2 +hou3 +hou4 +hu1 +hu2 +hu3 +hu4 +hua1 +hua2 +hua4 +huai2 +huai4 +huan1 +huan2 +huan3 +huan4 +huang1 +huang2 +huang3 +huang4 +hui1 +hui2 +hui3 +hui4 +hun1 +hun2 +hun4 +huo +huo1 +huo2 +huo3 +huo4 +i +j +ji1 +ji2 +ji3 +ji4 +jia +jia1 +jia2 +jia3 +jia4 +jian1 +jian2 +jian3 +jian4 +jiang1 +jiang2 +jiang3 +jiang4 +jiao1 +jiao2 +jiao3 +jiao4 +jie1 +jie2 +jie3 +jie4 +jin1 +jin2 +jin3 +jin4 +jing1 +jing2 +jing3 +jing4 +jiong3 +jiu1 +jiu2 +jiu3 +jiu4 +ju1 +ju2 +ju3 +ju4 +juan1 +juan2 +juan3 +juan4 +jue1 +jue2 +jue4 +jun1 +jun4 +k +ka1 +ka2 +ka3 +kai1 +kai2 +kai3 +kai4 +kan1 +kan2 +kan3 +kan4 +kang1 +kang2 +kang4 +kao1 +kao2 +kao3 +kao4 +ke1 +ke2 +ke3 +ke4 +ken3 +keng1 +kong1 +kong3 +kong4 +kou1 +kou2 +kou3 +kou4 +ku1 +ku2 +ku3 +ku4 +kua1 +kua3 +kua4 +kuai3 +kuai4 +kuan1 +kuan2 +kuan3 +kuang1 +kuang2 +kuang4 +kui1 +kui2 +kui3 +kui4 +kun1 +kun3 +kun4 +kuo4 +l +la +la1 +la2 +la3 +la4 +lai2 +lai4 +lan2 +lan3 +lan4 +lang1 +lang2 +lang3 +lang4 +lao1 +lao2 +lao3 +lao4 +le +le1 +le4 +lei +lei1 +lei2 +lei3 +lei4 +leng1 +leng2 +leng3 +leng4 +li +li1 +li2 +li3 +li4 +lia3 +lian2 +lian3 +lian4 +liang2 +liang3 +liang4 +liao1 +liao2 +liao3 +liao4 +lie1 +lie2 +lie3 +lie4 +lin1 +lin2 +lin3 +lin4 +ling2 +ling3 +ling4 +liu1 +liu2 +liu3 +liu4 +long1 +long2 +long3 +long4 +lou1 +lou2 +lou3 +lou4 +lu1 +lu2 +lu3 +lu4 +luan2 +luan3 +luan4 +lun1 +lun2 +lun4 +luo1 +luo2 +luo3 +luo4 +lv2 +lv3 +lv4 +lve3 +lve4 +m +ma +ma1 +ma2 +ma3 +ma4 +mai2 +mai3 +mai4 +man1 +man2 +man3 +man4 +mang2 +mang3 +mao1 +mao2 +mao3 +mao4 +me +mei2 +mei3 +mei4 +men +men1 +men2 +men4 +meng +meng1 +meng2 +meng3 +meng4 +mi1 +mi2 +mi3 +mi4 +mian2 +mian3 +mian4 +miao1 +miao2 +miao3 +miao4 +mie1 +mie4 +min2 +min3 +ming2 +ming3 +ming4 +miu4 +mo1 +mo2 +mo3 +mo4 +mou1 +mou2 +mou3 +mu2 +mu3 +mu4 +n +n2 +na1 +na2 +na3 +na4 +nai2 +nai3 +nai4 +nan1 +nan2 +nan3 +nan4 +nang1 +nang2 +nang3 +nao1 +nao2 +nao3 +nao4 +ne +ne2 +ne4 +nei3 +nei4 +nen4 +neng2 +ni1 +ni2 +ni3 +ni4 +nian1 +nian2 +nian3 +nian4 +niang2 +niang4 +niao2 +niao3 +niao4 +nie1 +nie4 +nin2 +ning2 +ning3 +ning4 +niu1 +niu2 +niu3 +niu4 +nong2 +nong4 +nou4 +nu2 +nu3 +nu4 +nuan3 +nuo2 +nuo4 +nv2 +nv3 +nve4 +o +o1 +o2 +ou1 +ou2 +ou3 +ou4 +p +pa1 +pa2 +pa4 +pai1 +pai2 +pai3 +pai4 +pan1 +pan2 +pan4 +pang1 +pang2 +pang4 +pao1 +pao2 +pao3 +pao4 +pei1 +pei2 +pei4 +pen1 +pen2 +pen4 +peng1 +peng2 +peng3 +peng4 +pi1 +pi2 +pi3 +pi4 +pian1 +pian2 +pian4 +piao1 +piao2 +piao3 +piao4 +pie1 +pie2 +pie3 +pin1 +pin2 +pin3 +pin4 +ping1 +ping2 +po1 +po2 +po3 +po4 +pou1 +pu1 +pu2 +pu3 +pu4 +q +qi1 +qi2 +qi3 +qi4 +qia1 +qia3 +qia4 +qian1 +qian2 +qian3 +qian4 +qiang1 +qiang2 +qiang3 +qiang4 +qiao1 +qiao2 +qiao3 +qiao4 +qie1 +qie2 +qie3 +qie4 +qin1 +qin2 +qin3 +qin4 +qing1 +qing2 +qing3 +qing4 +qiong1 +qiong2 +qiu1 +qiu2 +qiu3 +qu1 +qu2 +qu3 +qu4 +quan1 
+quan2 +quan3 +quan4 +que1 +que2 +que4 +qun2 +r +ran2 +ran3 +rang1 +rang2 +rang3 +rang4 +rao2 +rao3 +rao4 +re2 +re3 +re4 +ren2 +ren3 +ren4 +reng1 +reng2 +ri4 +rong1 +rong2 +rong3 +rou2 +rou4 +ru2 +ru3 +ru4 +ruan2 +ruan3 +rui3 +rui4 +run4 +ruo4 +s +sa1 +sa2 +sa3 +sa4 +sai1 +sai4 +san1 +san2 +san3 +san4 +sang1 +sang3 +sang4 +sao1 +sao2 +sao3 +sao4 +se4 +sen1 +seng1 +sha1 +sha2 +sha3 +sha4 +shai1 +shai2 +shai3 +shai4 +shan1 +shan3 +shan4 +shang +shang1 +shang3 +shang4 +shao1 +shao2 +shao3 +shao4 +she1 +she2 +she3 +she4 +shei2 +shen1 +shen2 +shen3 +shen4 +sheng1 +sheng2 +sheng3 +sheng4 +shi +shi1 +shi2 +shi3 +shi4 +shou1 +shou2 +shou3 +shou4 +shu1 +shu2 +shu3 +shu4 +shua1 +shua2 +shua3 +shua4 +shuai1 +shuai3 +shuai4 +shuan1 +shuan4 +shuang1 +shuang3 +shui2 +shui3 +shui4 +shun3 +shun4 +shuo1 +shuo4 +si1 +si2 +si3 +si4 +song1 +song3 +song4 +sou1 +sou3 +sou4 +su1 +su2 +su4 +suan1 +suan4 +sui1 +sui2 +sui3 +sui4 +sun1 +sun3 +suo +suo1 +suo2 +suo3 +t +ta1 +ta2 +ta3 +ta4 +tai1 +tai2 +tai4 +tan1 +tan2 +tan3 +tan4 +tang1 +tang2 +tang3 +tang4 +tao1 +tao2 +tao3 +tao4 +te4 +teng2 +ti1 +ti2 +ti3 +ti4 +tian1 +tian2 +tian3 +tiao1 +tiao2 +tiao3 +tiao4 +tie1 +tie2 +tie3 +tie4 +ting1 +ting2 +ting3 +tong1 +tong2 +tong3 +tong4 +tou +tou1 +tou2 +tou4 +tu1 +tu2 +tu3 +tu4 +tuan1 +tuan2 +tui1 +tui2 +tui3 +tui4 +tun1 +tun2 +tun4 +tuo1 +tuo2 +tuo3 +tuo4 +u +v +w +wa +wa1 +wa2 +wa3 +wa4 +wai1 +wai3 +wai4 +wan1 +wan2 +wan3 +wan4 +wang1 +wang2 +wang3 +wang4 +wei1 +wei2 +wei3 +wei4 +wen1 +wen2 +wen3 +wen4 +weng1 +weng4 +wo1 +wo2 +wo3 +wo4 +wu1 +wu2 +wu3 +wu4 +x +xi1 +xi2 +xi3 +xi4 +xia1 +xia2 +xia4 +xian1 +xian2 +xian3 +xian4 +xiang1 +xiang2 +xiang3 +xiang4 +xiao1 +xiao2 +xiao3 +xiao4 +xie1 +xie2 +xie3 +xie4 +xin1 +xin2 +xin4 +xing1 +xing2 +xing3 +xing4 +xiong1 +xiong2 +xiu1 +xiu3 +xiu4 +xu +xu1 +xu2 +xu3 +xu4 +xuan1 +xuan2 +xuan3 +xuan4 +xue1 +xue2 +xue3 +xue4 +xun1 +xun2 +xun4 +y +ya +ya1 +ya2 +ya3 +ya4 +yan1 +yan2 +yan3 +yan4 +yang1 +yang2 +yang3 +yang4 +yao1 +yao2 +yao3 +yao4 +ye1 +ye2 +ye3 +ye4 +yi +yi1 +yi2 +yi3 +yi4 +yin1 +yin2 +yin3 +yin4 +ying1 +ying2 +ying3 +ying4 +yo1 +yong1 +yong2 +yong3 +yong4 +you1 +you2 +you3 +you4 +yu1 +yu2 +yu3 +yu4 +yuan1 +yuan2 +yuan3 +yuan4 +yue1 +yue4 +yun1 +yun2 +yun3 +yun4 +z +za1 +za2 +za3 +zai1 +zai3 +zai4 +zan1 +zan2 +zan3 +zan4 +zang1 +zang4 +zao1 +zao2 +zao3 +zao4 +ze2 +ze4 +zei2 +zen3 +zeng1 +zeng4 +zha1 +zha2 +zha3 +zha4 +zhai1 +zhai2 +zhai3 +zhai4 +zhan1 +zhan2 +zhan3 +zhan4 +zhang1 +zhang2 +zhang3 +zhang4 +zhao1 +zhao2 +zhao3 +zhao4 +zhe +zhe1 +zhe2 +zhe3 +zhe4 +zhen1 +zhen2 +zhen3 +zhen4 +zheng1 +zheng2 +zheng3 +zheng4 +zhi1 +zhi2 +zhi3 +zhi4 +zhong1 +zhong2 +zhong3 +zhong4 +zhou1 +zhou2 +zhou3 +zhou4 +zhu1 +zhu2 +zhu3 +zhu4 +zhua1 +zhua2 +zhua3 +zhuai1 +zhuai3 +zhuai4 +zhuan1 +zhuan2 +zhuan3 +zhuan4 +zhuang1 +zhuang4 +zhui1 +zhui4 +zhun1 +zhun2 +zhun3 +zhuo1 +zhuo2 +zi +zi1 +zi2 +zi3 +zi4 +zong1 +zong2 +zong3 +zong4 +zou1 +zou2 +zou3 +zou4 +zu1 +zu2 +zu3 +zuan1 +zuan3 +zuan4 +zui2 +zui3 +zui4 +zun1 +zuo +zuo1 +zuo2 +zuo3 +zuo4 +{ +~ +¡ +¢ +£ +¥ +§ +¨ +© +« +® +¯ +° +± +² +³ +´ +µ +· +¹ +º +» +¼ +½ +¾ +¿ +À +Á + +à +Ä +Å +Æ +Ç +È +É +Ê +Í +Î +Ñ +Ó +Ö +× +Ø +Ú +Ü +Ý +Þ +ß +à +á +â +ã +ä +å +æ +ç +è +é +ê +ë +ì +í +î +ï +ð +ñ +ò +ó +ô +õ +ö +ø +ù +ú +û +ü +ý +Ā +ā +ă +ą +ć +Č +č +Đ +đ +ē +ė +ę +ě +ĝ +ğ +ħ +ī +į +İ +ı +Ł +ł +ń +ņ +ň +ŋ +Ō +ō +ő +œ +ř +Ś +ś +Ş +ş +Š +š +Ť +ť +ũ +ū +ź +Ż +ż +Ž +ž +ơ +ư +ǎ +ǐ +ǒ +ǔ +ǚ +ș +ț +ɑ +ɔ +ɕ +ə +ɛ +ɜ +ɡ +ɣ +ɪ +ɫ +ɴ +ɹ +ɾ +ʃ +ʊ +ʌ +ʒ +ʔ +ʰ +ʷ +ʻ +ʾ +ʿ +ˈ +ː +˙ +˜ +ˢ +́ +̅ +Α +Β +Δ +Ε +Θ +Κ +Λ +Μ +Ξ +Π +Σ +Τ +Φ +Χ +Ψ +Ω +ά +έ +ή +ί +α +β +γ 
+δ +ε +ζ +η +θ +ι +κ +λ +μ +ν +ξ +ο +π +ρ +ς +σ +τ +υ +φ +χ +ψ +ω +ϊ +ό +ύ +ώ +ϕ +ϵ +Ё +А +Б +В +Г +Д +Е +Ж +З +И +Й +К +Л +М +Н +О +П +Р +С +Т +У +Ф +Х +Ц +Ч +Ш +Щ +Ы +Ь +Э +Ю +Я +а +б +в +г +д +е +ж +з +и +й +к +л +м +н +о +п +р +с +т +у +ф +х +ц +ч +ш +щ +ъ +ы +ь +э +ю +я +ё +і +ְ +ִ +ֵ +ֶ +ַ +ָ +ֹ +ּ +־ +ׁ +א +ב +ג +ד +ה +ו +ז +ח +ט +י +כ +ל +ם +מ +ן +נ +ס +ע +פ +ק +ר +ש +ת +أ +ب +ة +ت +ج +ح +د +ر +ز +س +ص +ط +ع +ق +ك +ل +م +ن +ه +و +ي +َ +ُ +ِ +ْ +ก +ข +ง +จ +ต +ท +น +ป +ย +ร +ว +ส +ห +อ +ฮ +ั +า +ี +ึ +โ +ใ +ไ +่ +้ +์ +ḍ +Ḥ +ḥ +ṁ +ṃ +ṅ +ṇ +Ṛ +ṛ +Ṣ +ṣ +Ṭ +ṭ +ạ +ả +Ấ +ấ +ầ +ậ +ắ +ằ +ẻ +ẽ +ế +ề +ể +ễ +ệ +ị +ọ +ỏ +ố +ồ +ộ +ớ +ờ +ở +ụ +ủ +ứ +ữ +ἀ +ἁ +Ἀ +ἐ +ἔ +ἰ +ἱ +ὀ +ὁ +ὐ +ὲ +ὸ +ᾶ +᾽ +ῆ +ῇ +ῶ +‎ +‑ +‒ +– +— +― +‖ +† +‡ +• +… +‧ +‬ +′ +″ +⁄ +⁡ +⁰ +⁴ +⁵ +⁶ +⁷ +⁸ +⁹ +₁ +₂ +₃ +€ +₱ +₹ +₽ +℃ +ℏ +ℓ +№ +ℝ +™ +⅓ +⅔ +⅛ +→ +∂ +∈ +∑ +− +∗ +√ +∞ +∫ +≈ +≠ +≡ +≤ +≥ +⋅ +⋯ +█ +♪ +⟨ +⟩ +、 +。 +《 +》 +「 +」 +【 +】 +あ +う +え +お +か +が +き +ぎ +く +ぐ +け +げ +こ +ご +さ +し +じ +す +ず +せ +ぜ +そ +ぞ +た +だ +ち +っ +つ +で +と +ど +な +に +ね +の +は +ば +ひ +ぶ +へ +べ +ま +み +む +め +も +ゃ +や +ゆ +ょ +よ +ら +り +る +れ +ろ +わ +を +ん +ァ +ア +ィ +イ +ウ +ェ +エ +オ +カ +ガ +キ +ク +ケ +ゲ +コ +ゴ +サ +ザ +シ +ジ +ス +ズ +セ +ゾ +タ +ダ +チ +ッ +ツ +テ +デ +ト +ド +ナ +ニ +ネ +ノ +バ +パ +ビ +ピ +フ +プ +ヘ +ベ +ペ +ホ +ボ +ポ +マ +ミ +ム +メ +モ +ャ +ヤ +ュ +ユ +ョ +ヨ +ラ +リ +ル +レ +ロ +ワ +ン +・ +ー +ㄋ +ㄍ +ㄎ +ㄏ +ㄓ +ㄕ +ㄚ +ㄜ +ㄟ +ㄤ +ㄥ +ㄧ +ㄱ +ㄴ +ㄷ +ㄹ +ㅁ +ㅂ +ㅅ +ㅈ +ㅍ +ㅎ +ㅏ +ㅓ +ㅗ +ㅜ +ㅡ +ㅣ +㗎 +가 +각 +간 +갈 +감 +갑 +갓 +갔 +강 +같 +개 +거 +건 +걸 +겁 +것 +겉 +게 +겠 +겨 +결 +겼 +경 +계 +고 +곤 +골 +곱 +공 +과 +관 +광 +교 +구 +국 +굴 +귀 +귄 +그 +근 +글 +금 +기 +긴 +길 +까 +깍 +깔 +깜 +깨 +께 +꼬 +꼭 +꽃 +꾸 +꿔 +끔 +끗 +끝 +끼 +나 +난 +날 +남 +납 +내 +냐 +냥 +너 +넘 +넣 +네 +녁 +년 +녕 +노 +녹 +놀 +누 +눈 +느 +는 +늘 +니 +님 +닙 +다 +닥 +단 +달 +닭 +당 +대 +더 +덕 +던 +덥 +데 +도 +독 +동 +돼 +됐 +되 +된 +될 +두 +둑 +둥 +드 +들 +등 +디 +따 +딱 +딸 +땅 +때 +떤 +떨 +떻 +또 +똑 +뚱 +뛰 +뜻 +띠 +라 +락 +란 +람 +랍 +랑 +래 +랜 +러 +런 +럼 +렇 +레 +려 +력 +렵 +렸 +로 +록 +롬 +루 +르 +른 +를 +름 +릉 +리 +릴 +림 +마 +막 +만 +많 +말 +맑 +맙 +맛 +매 +머 +먹 +멍 +메 +면 +명 +몇 +모 +목 +몸 +못 +무 +문 +물 +뭐 +뭘 +미 +민 +밌 +밑 +바 +박 +밖 +반 +받 +발 +밤 +밥 +방 +배 +백 +밸 +뱀 +버 +번 +벌 +벚 +베 +벼 +벽 +별 +병 +보 +복 +본 +볼 +봐 +봤 +부 +분 +불 +비 +빔 +빛 +빠 +빨 +뼈 +뽀 +뿅 +쁘 +사 +산 +살 +삼 +샀 +상 +새 +색 +생 +서 +선 +설 +섭 +섰 +성 +세 +셔 +션 +셨 +소 +속 +손 +송 +수 +숙 +순 +술 +숫 +숭 +숲 +쉬 +쉽 +스 +슨 +습 +슷 +시 +식 +신 +실 +싫 +심 +십 +싶 +싸 +써 +쓰 +쓴 +씌 +씨 +씩 +씬 +아 +악 +안 +않 +알 +야 +약 +얀 +양 +얘 +어 +언 +얼 +엄 +업 +없 +었 +엉 +에 +여 +역 +연 +염 +엽 +영 +옆 +예 +옛 +오 +온 +올 +옷 +옹 +와 +왔 +왜 +요 +욕 +용 +우 +운 +울 +웃 +워 +원 +월 +웠 +위 +윙 +유 +육 +윤 +으 +은 +을 +음 +응 +의 +이 +익 +인 +일 +읽 +임 +입 +있 +자 +작 +잔 +잖 +잘 +잡 +잤 +장 +재 +저 +전 +점 +정 +제 +져 +졌 +조 +족 +좀 +종 +좋 +죠 +주 +준 +줄 +중 +줘 +즈 +즐 +즘 +지 +진 +집 +짜 +짝 +쩌 +쪼 +쪽 +쫌 +쭈 +쯔 +찌 +찍 +차 +착 +찾 +책 +처 +천 +철 +체 +쳐 +쳤 +초 +촌 +추 +출 +춤 +춥 +춰 +치 +친 +칠 +침 +칩 +칼 +커 +켓 +코 +콩 +쿠 +퀴 +크 +큰 +큽 +키 +킨 +타 +태 +터 +턴 +털 +테 +토 +통 +투 +트 +특 +튼 +틀 +티 +팀 +파 +팔 +패 +페 +펜 +펭 +평 +포 +폭 +표 +품 +풍 +프 +플 +피 +필 +하 +학 +한 +할 +함 +합 +항 +해 +햇 +했 +행 +허 +험 +형 +혜 +호 +혼 +홀 +화 +회 +획 +후 +휴 +흐 +흔 +희 +히 +힘 +ﷺ +ﷻ +! +, +? +� +𠮶 diff --git a/egs/wenetspeech4tts/TTS/local/audio.py b/egs/wenetspeech4tts/TTS/local/audio.py new file mode 100644 index 0000000000..b643e3de0f --- /dev/null +++ b/egs/wenetspeech4tts/TTS/local/audio.py @@ -0,0 +1,122 @@ +# Copyright (c) 2024 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/jik876/hifi-gan under the MIT license. +# LICENSE is in incl_licenses directory. 
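+
+# NOTE: This module computes BigVGAN-compatible mel spectrograms (librosa mel
+# filterbank with Slaney norm, Hann-window torch.stft); its defaults match the
+# 24 kHz / 100-bin configuration used by local/compute_mel_feat.py.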
+ +import math +import os +import pathlib +import random +from typing import List, Optional, Tuple + +import librosa +import numpy as np +import torch +import torch.utils.data +from librosa.filters import mel as librosa_mel_fn +from tqdm import tqdm + +# from env import AttrDict + +MAX_WAV_VALUE = 32767.0 # NOTE: 32768.0 -1 to prevent int16 overflow (results in popping sound in corner cases) + + +def dynamic_range_compression(x, C=1, clip_val=1e-5): + return np.log(np.clip(x, a_min=clip_val, a_max=None) * C) + + +def dynamic_range_decompression(x, C=1): + return np.exp(x) / C + + +def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): + return torch.log(torch.clamp(x, min=clip_val) * C) + + +def dynamic_range_decompression_torch(x, C=1): + return torch.exp(x) / C + + +def spectral_normalize_torch(magnitudes): + return dynamic_range_compression_torch(magnitudes) + + +def spectral_de_normalize_torch(magnitudes): + return dynamic_range_decompression_torch(magnitudes) + + +mel_basis_cache = {} +hann_window_cache = {} + + +def mel_spectrogram( + y: torch.Tensor, + n_fft: int = 1024, + num_mels: int = 100, + sampling_rate: int = 24_000, + hop_size: int = 256, + win_size: int = 1024, + fmin: int = 0, + fmax: int = None, + center: bool = False, +) -> torch.Tensor: + """ + Calculate the mel spectrogram of an input signal. + This function uses slaney norm for the librosa mel filterbank (using librosa.filters.mel) and uses Hann window for STFT (using torch.stft). + + Args: + y (torch.Tensor): Input signal. + n_fft (int): FFT size. + num_mels (int): Number of mel bins. + sampling_rate (int): Sampling rate of the input signal. + hop_size (int): Hop size for STFT. + win_size (int): Window size for STFT. + fmin (int): Minimum frequency for mel filterbank. + fmax (int): Maximum frequency for mel filterbank. If None, defaults to half the sampling rate (fmax = sr / 2.0) inside librosa_mel_fn + center (bool): Whether to pad the input to center the frames. Default is False. + + Returns: + torch.Tensor: Mel spectrogram. + """ + if torch.min(y) < -1.0: + print(f"[WARNING] Min value of input waveform signal is {torch.min(y)}") + if torch.max(y) > 1.0: + print(f"[WARNING] Max value of input waveform signal is {torch.max(y)}") + + device = y.device + key = f"{n_fft}_{num_mels}_{sampling_rate}_{hop_size}_{win_size}_{fmin}_{fmax}_{device}" + + if key not in mel_basis_cache: + mel = librosa_mel_fn( + sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax + ) + mel_basis_cache[key] = torch.from_numpy(mel).float().to(device) + hann_window_cache[key] = torch.hann_window(win_size).to(device) + + mel_basis = mel_basis_cache[key] + hann_window = hann_window_cache[key] + + padding = (n_fft - hop_size) // 2 + y = torch.nn.functional.pad( + y.unsqueeze(1), (padding, padding), mode="reflect" + ).squeeze(1) + + spec = torch.stft( + y, + n_fft, + hop_length=hop_size, + win_length=win_size, + window=hann_window, + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=True, + ) + spec = torch.sqrt(torch.view_as_real(spec).pow(2).sum(-1) + 1e-9) + + mel_spec = torch.matmul(mel_basis, spec) + mel_spec = spectral_normalize_torch(mel_spec) + + return mel_spec diff --git a/egs/wenetspeech4tts/TTS/local/compute_mel_feat.py b/egs/wenetspeech4tts/TTS/local/compute_mel_feat.py new file mode 100755 index 0000000000..5292c75add --- /dev/null +++ b/egs/wenetspeech4tts/TTS/local/compute_mel_feat.py @@ -0,0 +1,218 @@ +#!/usr/bin/env python3 +# Copyright 2021-2023 Xiaomi Corp. 
(authors: Fangjun Kuang,
+#                                                       Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This file computes fbank features of the WenetSpeech4TTS dataset.
+It looks for manifests in the directory data/manifests.
+
+The generated fbank features are saved in data/fbank.
+"""
+
+import argparse
+import logging
+import os
+from pathlib import Path
+
+import torch
+from fbank import MatchaFbank, MatchaFbankConfig
+from lhotse import CutSet, LilcomChunkyWriter
+from lhotse.recipes.utils import read_manifests_if_cached
+
+from icefall.utils import get_executor, str2bool
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--num-jobs",
+        type=int,
+        default=1,
+        help="Number of parallel jobs to use for feature extraction",
+    )
+
+    parser.add_argument(
+        "--src-dir",
+        type=Path,
+        default=Path("data/manifests"),
+        help="Path to the manifest files",
+    )
+
+    parser.add_argument(
+        "--output-dir",
+        type=Path,
+        default=Path("data/fbank"),
+        help="Path to the directory where the extracted features are saved",
+    )
+
+    parser.add_argument(
+        "--dataset-parts",
+        type=str,
+        default="Basic",
+        help="Space separated dataset parts",
+    )
+
+    parser.add_argument(
+        "--prefix",
+        type=str,
+        default="wenetspeech4tts",
+        help="prefix of the manifest file",
+    )
+
+    parser.add_argument(
+        "--suffix",
+        type=str,
+        default="jsonl.gz",
+        help="suffix of the manifest file",
+    )
+
+    parser.add_argument(
+        "--split",
+        type=int,
+        default=100,
+        help="Split the cut_set into multiple parts",
+    )
+
+    parser.add_argument(
+        "--resample-to-24kHz",
+        type=str2bool,
+        default=True,
+        help="Resample the audio to 24kHz",
+    )
+
+    parser.add_argument(
+        "--extractor",
+        type=str,
+        choices=["bigvgan", "hifigan"],
+        default="bigvgan",
+        help="The type of extractor to use",
+    )
+    return parser
+
+
+def compute_fbank(args):
+    src_dir = Path(args.src_dir)
+    output_dir = Path(args.output_dir)
+    Path(args.output_dir).mkdir(parents=True, exist_ok=True)
+
+    num_jobs = min(args.num_jobs, os.cpu_count())
+    dataset_parts = args.dataset_parts.replace("--dataset-parts", "").strip().split(" ")
+
+    logging.info(f"num_jobs: {num_jobs}")
+    logging.info(f"src_dir: {src_dir}")
+    logging.info(f"output_dir: {output_dir}")
+    logging.info(f"dataset_parts: {dataset_parts}")
+    if args.extractor == "bigvgan":
+        config = MatchaFbankConfig(
+            n_fft=1024,
+            n_mels=100,
+            sampling_rate=24_000,
+            hop_length=256,
+            win_length=1024,
+            f_min=0,
+            f_max=None,
+        )
+    elif args.extractor == "hifigan":
+        config = MatchaFbankConfig(
+            n_fft=1024,
+            n_mels=80,
+            sampling_rate=22050,
+            hop_length=256,
+            win_length=1024,
+            f_min=0,
+            f_max=8000,
+        )
+    else:
+        raise NotImplementedError(f"Extractor {args.extractor} is not implemented")
+
+    extractor = MatchaFbank(config)
+
+    manifests = read_manifests_if_cached(
+        dataset_parts=dataset_parts,
+        output_dir=args.src_dir,
+        prefix=args.prefix,
+        suffix=args.suffix,
+        
types=["recordings", "supervisions", "cuts"], + ) + + with get_executor() as ex: + for partition, m in manifests.items(): + logging.info( + f"Processing partition: {partition} CUDA: {torch.cuda.is_available()}" + ) + try: + cut_set = CutSet.from_manifests( + recordings=m["recordings"], + supervisions=m["supervisions"], + ) + except Exception: + cut_set = m["cuts"] + + if args.split > 1: + cut_sets = cut_set.split(args.split) + else: + cut_sets = [cut_set] + + for idx, part in enumerate(cut_sets): + if args.split > 1: + storage_path = f"{args.output_dir}/{args.prefix}_{args.extractor}_{partition}_{idx}" + else: + storage_path = ( + f"{args.output_dir}/{args.prefix}_{args.extractor}_{partition}" + ) + + if args.resample_to_24kHz: + part = part.resample(24000) + + with torch.no_grad(): + part = part.compute_and_store_features( + extractor=extractor, + storage_path=storage_path, + num_jobs=num_jobs if ex is None else 64, + executor=ex, + storage_type=LilcomChunkyWriter, + ) + + if args.split > 1: + cuts_filename = ( + f"{args.prefix}_cuts_{partition}.{idx}.{args.suffix}" + ) + else: + cuts_filename = f"{args.prefix}_cuts_{partition}.{args.suffix}" + + part.to_file(f"{args.output_dir}/{cuts_filename}") + logging.info(f"Saved {cuts_filename}") + + +if __name__ == "__main__": + # Torch's multithreaded behavior needs to be disabled or + # it wastes a lot of CPU and slow things down. + # Do this outside of main() in case it needs to take effect + # even when we are not invoking the main (e.g. when spawning subprocesses). + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + + args = get_parser().parse_args() + compute_fbank(args) diff --git a/egs/wenetspeech4tts/TTS/local/compute_wer.sh b/egs/wenetspeech4tts/TTS/local/compute_wer.sh new file mode 100644 index 0000000000..2835463837 --- /dev/null +++ b/egs/wenetspeech4tts/TTS/local/compute_wer.sh @@ -0,0 +1,26 @@ +wav_dir=$1 +wav_files=$(ls $wav_dir/*.wav) +# if wav_files is empty, then exit +if [ -z "$wav_files" ]; then + exit 1 +fi +label_file=$2 +model_path=local/sherpa-onnx-paraformer-zh-2023-09-14 + +if [ ! -d $model_path ]; then + pip install sherpa-onnx + wget -nc https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-paraformer-zh-2023-09-14.tar.bz2 + tar xvf sherpa-onnx-paraformer-zh-2023-09-14.tar.bz2 -C local +fi + +python3 local/offline-decode-files.py \ + --tokens=$model_path/tokens.txt \ + --paraformer=$model_path/model.int8.onnx \ + --num-threads=2 \ + --decoding-method=greedy_search \ + --debug=false \ + --sample-rate=24000 \ + --log-dir $wav_dir \ + --feature-dim=80 \ + --label $label_file \ + $wav_files diff --git a/egs/wenetspeech4tts/TTS/local/fbank.py b/egs/wenetspeech4tts/TTS/local/fbank.py new file mode 120000 index 0000000000..3cfb7fe3f4 --- /dev/null +++ b/egs/wenetspeech4tts/TTS/local/fbank.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/matcha/fbank.py \ No newline at end of file diff --git a/egs/wenetspeech4tts/TTS/local/offline-decode-files.py b/egs/wenetspeech4tts/TTS/local/offline-decode-files.py new file mode 100755 index 0000000000..fa6cbdb3eb --- /dev/null +++ b/egs/wenetspeech4tts/TTS/local/offline-decode-files.py @@ -0,0 +1,495 @@ +#!/usr/bin/env python3 +# +# Copyright (c) 2023 by manyeyes +# Copyright (c) 2023 Xiaomi Corporation + +""" +This file demonstrates how to use sherpa-onnx Python API to transcribe +file(s) with a non-streaming model. 
+ +(1) For paraformer + + ./python-api-examples/offline-decode-files.py \ + --tokens=/path/to/tokens.txt \ + --paraformer=/path/to/paraformer.onnx \ + --num-threads=2 \ + --decoding-method=greedy_search \ + --debug=false \ + --sample-rate=16000 \ + --feature-dim=80 \ + /path/to/0.wav \ + /path/to/1.wav + +(2) For transducer models from icefall + + ./python-api-examples/offline-decode-files.py \ + --tokens=/path/to/tokens.txt \ + --encoder=/path/to/encoder.onnx \ + --decoder=/path/to/decoder.onnx \ + --joiner=/path/to/joiner.onnx \ + --num-threads=2 \ + --decoding-method=greedy_search \ + --debug=false \ + --sample-rate=16000 \ + --feature-dim=80 \ + /path/to/0.wav \ + /path/to/1.wav + +(3) For CTC models from NeMo + +python3 ./python-api-examples/offline-decode-files.py \ + --tokens=./sherpa-onnx-nemo-ctc-en-citrinet-512/tokens.txt \ + --nemo-ctc=./sherpa-onnx-nemo-ctc-en-citrinet-512/model.onnx \ + --num-threads=2 \ + --decoding-method=greedy_search \ + --debug=false \ + ./sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/0.wav \ + ./sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/1.wav \ + ./sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/8k.wav + +(4) For Whisper models + +python3 ./python-api-examples/offline-decode-files.py \ + --whisper-encoder=./sherpa-onnx-whisper-base.en/base.en-encoder.int8.onnx \ + --whisper-decoder=./sherpa-onnx-whisper-base.en/base.en-decoder.int8.onnx \ + --tokens=./sherpa-onnx-whisper-base.en/base.en-tokens.txt \ + --whisper-task=transcribe \ + --num-threads=1 \ + ./sherpa-onnx-whisper-base.en/test_wavs/0.wav \ + ./sherpa-onnx-whisper-base.en/test_wavs/1.wav \ + ./sherpa-onnx-whisper-base.en/test_wavs/8k.wav + +(5) For CTC models from WeNet + +python3 ./python-api-examples/offline-decode-files.py \ + --wenet-ctc=./sherpa-onnx-zh-wenet-wenetspeech/model.onnx \ + --tokens=./sherpa-onnx-zh-wenet-wenetspeech/tokens.txt \ + ./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/0.wav \ + ./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/1.wav \ + ./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/8k.wav + +(6) For tdnn models of the yesno recipe from icefall + +python3 ./python-api-examples/offline-decode-files.py \ + --sample-rate=8000 \ + --feature-dim=23 \ + --tdnn-model=./sherpa-onnx-tdnn-yesno/model-epoch-14-avg-2.onnx \ + --tokens=./sherpa-onnx-tdnn-yesno/tokens.txt \ + ./sherpa-onnx-tdnn-yesno/test_wavs/0_0_0_1_0_0_0_1.wav \ + ./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_0_0_0_1_0.wav \ + ./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_0_0_1_1_1.wav + +Please refer to +https://k2-fsa.github.io/sherpa/onnx/index.html +to install sherpa-onnx and to download non-streaming pre-trained models +used in this file. +""" +import argparse +import time +import wave +from pathlib import Path +from typing import List, Tuple + +import numpy as np +import sherpa_onnx +import soundfile as sf + + +def get_args(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--tokens", + type=str, + help="Path to tokens.txt", + ) + + parser.add_argument( + "--hotwords-file", + type=str, + default="", + help=""" + The file containing hotwords, one words/phrases per line, like + HELLO WORLD + 你好世界 + """, + ) + + parser.add_argument( + "--hotwords-score", + type=float, + default=1.5, + help=""" + The hotword score of each token for biasing word/phrase. Used only if + --hotwords-file is given. 
+ """, + ) + + parser.add_argument( + "--modeling-unit", + type=str, + default="", + help=""" + The modeling unit of the model, valid values are cjkchar, bpe, cjkchar+bpe. + Used only when hotwords-file is given. + """, + ) + + parser.add_argument( + "--bpe-vocab", + type=str, + default="", + help=""" + The path to the bpe vocabulary, the bpe vocabulary is generated by + sentencepiece, you can also export the bpe vocabulary through a bpe model + by `scripts/export_bpe_vocab.py`. Used only when hotwords-file is given + and modeling-unit is bpe or cjkchar+bpe. + """, + ) + + parser.add_argument( + "--encoder", + default="", + type=str, + help="Path to the encoder model", + ) + + parser.add_argument( + "--decoder", + default="", + type=str, + help="Path to the decoder model", + ) + + parser.add_argument( + "--joiner", + default="", + type=str, + help="Path to the joiner model", + ) + + parser.add_argument( + "--paraformer", + default="", + type=str, + help="Path to the model.onnx from Paraformer", + ) + + parser.add_argument( + "--nemo-ctc", + default="", + type=str, + help="Path to the model.onnx from NeMo CTC", + ) + + parser.add_argument( + "--wenet-ctc", + default="", + type=str, + help="Path to the model.onnx from WeNet CTC", + ) + + parser.add_argument( + "--tdnn-model", + default="", + type=str, + help="Path to the model.onnx for the tdnn model of the yesno recipe", + ) + + parser.add_argument( + "--num-threads", + type=int, + default=1, + help="Number of threads for neural network computation", + ) + + parser.add_argument( + "--whisper-encoder", + default="", + type=str, + help="Path to whisper encoder model", + ) + + parser.add_argument( + "--whisper-decoder", + default="", + type=str, + help="Path to whisper decoder model", + ) + + parser.add_argument( + "--whisper-language", + default="", + type=str, + help="""It specifies the spoken language in the input audio file. + Example values: en, fr, de, zh, jp. + Available languages for multilingual models can be found at + https://github.com/openai/whisper/blob/main/whisper/tokenizer.py#L10 + If not specified, we infer the language from the input audio file. + """, + ) + + parser.add_argument( + "--whisper-task", + default="transcribe", + choices=["transcribe", "translate"], + type=str, + help="""For multilingual models, if you specify translate, the output + will be in English. + """, + ) + + parser.add_argument( + "--whisper-tail-paddings", + default=-1, + type=int, + help="""Number of tail padding frames. + We have removed the 30-second constraint from whisper, so you need to + choose the amount of tail padding frames by yourself. + Use -1 to use a default value for tail padding. + """, + ) + + parser.add_argument( + "--blank-penalty", + type=float, + default=0.0, + help=""" + The penalty applied on blank symbol during decoding. + Note: It is a positive value that would be applied to logits like + this `logits[:, 0] -= blank_penalty` (suppose logits.shape is + [batch_size, vocab] and blank id is 0). + """, + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="Valid values are greedy_search and modified_beam_search", + ) + parser.add_argument( + "--debug", + type=bool, + default=False, + help="True to show debug messages", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="""Sample rate of the feature extractor. Must match the one + expected by the model. 
Note: The input sound files can have a
+        different sample rate from this argument.""",
+    )
+
+    parser.add_argument(
+        "--feature-dim",
+        type=int,
+        default=80,
+        help="Feature dimension. Must match the one expected by the model",
+    )
+
+    parser.add_argument(
+        "sound_files",
+        type=str,
+        nargs="+",
+        help="The input sound file(s) to decode. Each file must be of WAVE "
+        "format with a single channel, and each sample has 16-bit, "
+        "i.e., int16_t. "
+        "The sample rate of the file can be arbitrary and does not need to "
+        "be 16 kHz",
+    )
+
+    parser.add_argument(
+        "--name",
+        type=str,
+        default="",
+        help="Name tag used in the result file names recogs-<name>.txt and errs-<name>.txt",
+    )
+
+    parser.add_argument(
+        "--log-dir",
+        type=str,
+        default="",
+        help="The directory where recognition results and error stats are written",
+    )
+
+    parser.add_argument(
+        "--label",
+        type=str,
+        default=None,
+        help="Path to the label file; each line is "
+        "audio_path|prompt_text|prompt_audio|text",
+    )
+    return parser.parse_args()
+
+
+def assert_file_exists(filename: str):
+    assert Path(filename).is_file(), (
+        f"{filename} does not exist!\n"
+        "Please refer to "
+        "https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it"
+    )
+
+
+def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
+    """
+    Args:
+      wave_filename:
+        Path to a wave file. It should be single channel and can be of type
+        32-bit floating point PCM. Its sample rate does not need to be 24kHz.
+
+    Returns:
+      Return a tuple containing:
+       - A 1-D array of dtype np.float32 containing the samples,
+         which are normalized to the range [-1, 1].
+       - Sample rate of the wave file.
+    """
+
+    samples, sample_rate = sf.read(wave_filename, dtype="float32")
+    assert (
+        samples.ndim == 1
+    ), f"Expected single channel, but got {samples.ndim} channels."
+
+    samples_float32 = samples.astype(np.float32)
+
+    return samples_float32, sample_rate
+
+
+def normalize_text_alimeeting(text: str) -> str:
+    """
+    Text normalization similar to M2MeT challenge baseline.
+ See: https://github.com/yufan-aslp/AliMeeting/blob/main/asr/local/text_normalize.pl + """ + import re + + text = text.replace(" ", "") + text = text.replace("", "") + text = text.replace("<%>", "") + text = text.replace("<->", "") + text = text.replace("<$>", "") + text = text.replace("<#>", "") + text = text.replace("<_>", "") + text = text.replace("", "") + text = text.replace("`", "") + text = text.replace("&", "") + text = text.replace(",", "") + if re.search("[a-zA-Z]", text): + text = text.upper() + text = text.replace("A", "A") + text = text.replace("a", "A") + text = text.replace("b", "B") + text = text.replace("c", "C") + text = text.replace("k", "K") + text = text.replace("t", "T") + text = text.replace(",", "") + text = text.replace("丶", "") + text = text.replace("。", "") + text = text.replace("、", "") + text = text.replace("?", "") + return text + + +def main(): + args = get_args() + assert_file_exists(args.tokens) + assert args.num_threads > 0, args.num_threads + + assert len(args.nemo_ctc) == 0, args.nemo_ctc + assert len(args.wenet_ctc) == 0, args.wenet_ctc + assert len(args.whisper_encoder) == 0, args.whisper_encoder + assert len(args.whisper_decoder) == 0, args.whisper_decoder + assert len(args.tdnn_model) == 0, args.tdnn_model + + assert_file_exists(args.paraformer) + + recognizer = sherpa_onnx.OfflineRecognizer.from_paraformer( + paraformer=args.paraformer, + tokens=args.tokens, + num_threads=args.num_threads, + sample_rate=args.sample_rate, + feature_dim=args.feature_dim, + decoding_method=args.decoding_method, + debug=args.debug, + ) + + print("Started!") + start_time = time.time() + + streams, results = [], [] + total_duration = 0 + + for i, wave_filename in enumerate(args.sound_files): + assert_file_exists(wave_filename) + samples, sample_rate = read_wave(wave_filename) + duration = len(samples) / sample_rate + total_duration += duration + s = recognizer.create_stream() + s.accept_waveform(sample_rate, samples) + + streams.append(s) + if i % 10 == 0: + recognizer.decode_streams(streams) + results += [s.result.text for s in streams] + streams = [] + print(f"Processed {i} files") + # process the last batch + if streams: + recognizer.decode_streams(streams) + results += [s.result.text for s in streams] + end_time = time.time() + print("Done!") + + results_dict = {} + for wave_filename, result in zip(args.sound_files, results): + print(f"{wave_filename}\n{result}") + print("-" * 10) + wave_basename = Path(wave_filename).stem + results_dict[wave_basename] = result + + elapsed_seconds = end_time - start_time + rtf = elapsed_seconds / total_duration + print(f"num_threads: {args.num_threads}") + print(f"decoding_method: {args.decoding_method}") + print(f"Wave duration: {total_duration:.3f} s") + print(f"Elapsed time: {elapsed_seconds:.3f} s") + print( + f"Real time factor (RTF): {elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}" + ) + if args.label: + from icefall.utils import store_transcripts, write_error_stats + + labels_dict = {} + with open(args.label, "r") as f: + for line in f: + # fields = line.strip().split(" ") + # fields = [item for item in fields if item] + # assert len(fields) == 4 + # prompt_text, prompt_audio, text, audio_path = fields + + fields = line.strip().split("|") + fields = [item for item in fields if item] + assert len(fields) == 4 + audio_path, prompt_text, prompt_audio, text = fields + labels_dict[Path(audio_path).stem] = normalize_text_alimeeting(text) + + final_results = [] + for key, value in results_dict.items(): + 
final_results.append((key, labels_dict[key], value))
+
+        store_transcripts(
+            filename=f"{args.log_dir}/recogs-{args.name}.txt", texts=final_results
+        )
+        with open(f"{args.log_dir}/errs-{args.name}.txt", "w") as f:
+            write_error_stats(f, "test-set", final_results, enable_log=True)
+
+        with open(f"{args.log_dir}/errs-{args.name}.txt", "r") as f:
+            print(f.readline())  # WER
+            print(f.readline())  # Detailed errors
+
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/wenetspeech4tts/TTS/prepare.sh b/egs/wenetspeech4tts/TTS/prepare.sh
index 54e140dbb1..81efd3a796 100755
--- a/egs/wenetspeech4tts/TTS/prepare.sh
+++ b/egs/wenetspeech4tts/TTS/prepare.sh
@@ -98,3 +98,44 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
   fi
   python3 ./local/display_manifest_statistics.py --manifest-dir ${audio_feats_dir}
 fi
+
+subset="Basic"
+prefix="wenetspeech4tts"
+if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
+  log "Stage 5: Generate fbank (used by ./f5-tts)"
+  mkdir -p data/fbank
+  if [ ! -e data/fbank/.${prefix}.done ]; then
+    ./local/compute_mel_feat.py --dataset-parts $subset --split 100
+    touch data/fbank/.${prefix}.done
+  fi
+fi
+
+if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
+  log "Stage 6: Split the ${prefix} cuts into train, valid and test sets (used by ./f5-tts)"
+  if [ ! -f data/fbank/${prefix}_cuts_${subset}.jsonl.gz ]; then
+    echo "Combining ${prefix} cuts"
+    pieces=$(find data/fbank/ -name "${prefix}_cuts_${subset}.*.jsonl.gz")
+    lhotse combine $pieces data/fbank/${prefix}_cuts_${subset}.jsonl.gz
+  fi
+  if [ ! -e data/fbank/.${prefix}_split.done ]; then
+    echo "Splitting ${prefix} cuts into train, valid and test sets"
+
+    lhotse subset --last 800 \
+      data/fbank/${prefix}_cuts_${subset}.jsonl.gz \
+      data/fbank/${prefix}_cuts_validtest.jsonl.gz
+    lhotse subset --first 400 \
+      data/fbank/${prefix}_cuts_validtest.jsonl.gz \
+      data/fbank/${prefix}_cuts_valid.jsonl.gz
+    lhotse subset --last 400 \
+      data/fbank/${prefix}_cuts_validtest.jsonl.gz \
+      data/fbank/${prefix}_cuts_test.jsonl.gz
+
+    rm data/fbank/${prefix}_cuts_validtest.jsonl.gz
+
+    n=$(( $(gunzip -c data/fbank/${prefix}_cuts_${subset}.jsonl.gz | wc -l) - 800 ))
+    lhotse subset --first $n \
+      data/fbank/${prefix}_cuts_${subset}.jsonl.gz \
+      data/fbank/${prefix}_cuts_train.jsonl.gz
+    touch data/fbank/.${prefix}_split.done
+  fi
+fi
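+
+# After stages 5 and 6, data/fbank/ is expected to contain
+# ${prefix}_cuts_train.jsonl.gz, ${prefix}_cuts_valid.jsonl.gz and
+# ${prefix}_cuts_test.jsonl.gz (800 cuts held out, split evenly into valid
+# and test); these are the manifests used by ./f5-tts.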