diff --git a/examples/mms/lid_rerank/README.md b/examples/mms/lid_rerank/README.md
new file mode 100644
index 0000000000..3fa599f08a
--- /dev/null
+++ b/examples/mms/lid_rerank/README.md
@@ -0,0 +1,115 @@
+# N-best Re-ranking for Multilingual LID+ASR
+This project provides N-best re-ranking, a simple inference procedure, for improving multilingual speech recognition (ASR) "in the wild", where models are expected to first predict the language identity (LID) of an utterance before transcribing it. Our method considers the N-best LID predictions for each utterance, runs ASR in each of the N candidate languages, and then uses external features over the candidate transcriptions to re-rank them.
+
+The workflow is as follows: 1) run LID+ASR inference (MMS and Whisper are supported), 2) compute external re-ranking features, 3) tune the feature coefficients on a dev set, and 4) apply them to the test set.
+
+For more information about our method, please refer to the paper: "Improving Multilingual ASR in the Wild Using Simple N-best Re-ranking".
+
+## 1) Commands to Run LID+ASR Inference
+
+### Data Prep
+Prepare a text file with one wav file path per line:
+```
+#/path/to/wav/list
+/path/to/audio1.wav
+/path/to/audio2.wav
+/path/to/audio3.wav
+```
+
+The following workflow also assumes that LID and ASR references are available (at least for the dev set). We use [3-letter iso codes](https://dl.fbaipublicfiles.com/mms/lid/mms1b_l4017_langs.html) for both Whisper and MMS.
+
+Next, run either Whisper-based or MMS-based LID+ASR.
+
+### Whisper
+Refer to the [Whisper documentation](https://github.com/openai/whisper) for installation instructions.
+
+First run LID:
+```
+python whisper/infer_lid.py --wavs "path/to/wav/list" --dst "path/to/lid/results" --model large-v2 --n 10
+```
+Note that the size of the N-best list is set to 10 here.
+
+Then run ASR, using the top-N LID predictions:
+```
+python whisper/infer_asr.py --wavs "path/to/wav/list" --lids "path/to/lid/results"/nbest_lid --dst "path/to/asr/results" --model large-v2
+```
+
+### MMS
+Refer to the [Fairseq documentation](https://github.com/facebookresearch/fairseq/tree/main) for installation instructions.
+
+Prepare data and models following the [instructions from the MMS repository](https://github.com/facebookresearch/fairseq/tree/main/examples/mms). Note that the MMS backend expects a slightly different wav list format, which can be obtained via:
+```
+python mms/prep_wav_list.py --src "/path/to/wav/list" --dst "/path/to/wav/manifest.tsv"
+```
+Note that MMS also expects LID references in a file named `"/path/to/wav/manifest.lang"`.
+
+Then run LID:
+```
+cd "path/to/fairseq/dir"
+PYTHONPATH='.' python3 examples/mms/lid/infer.py "path/to/dict/dir" --path "path/to/model" --task audio_classification --infer-manifest "path/to/wav/manifest.tsv" --output-path "path/to/lid/results" --top-k 10
+```
+Note that the size of the N-best list is set to 10 here.
+
+Then run ASR, using the top-N LID predictions. Since MMS uses language-specific parameters, we've parallelized inference across languages:
+```
+#Split data by language
+python mms/split_by_lang.py --wavs_tsv "/path/to/wav/manifest.tsv" --lid_preds "path/to/lid/results"/predictions.txt --dst "path/to/data/split"
+
+#Write language-specific ASR python commands to an executable file
+python mms/make_parallel_single_runs.py --dump "path/to/data/split" --model "path/to/model" --dst "path/to/asr/results" --fairseq_dir "path/to/fairseq/dir" > run.sh
+
+#Run each language sequentially (you can also parallelize this)
+. ./run.sh
+
+#Merge language-specific results back into the original order
+python mms/merge_by_lang.py --dump "path/to/data/split" --exp "path/to/asr/results"
+```
+
+## 2) Commands to Compute External Re-ranking Features
+
+### MaLA - Large Language Model
+```
+python mala/infer.py --txt "path/to/asr/results"/nbest_asr_hyp --dst "path/to/lm/results"
+```
+
+### NLLB - Written LID Model
+Download the model from the [official source](https://github.com/facebookresearch/fairseq/tree/nllb#lid-model).
+
+```
+python nllb/infer.py --txt "path/to/asr/results"/nbest_asr_hyp --lid "path/to/lid/results"/nbest_lid --dst "path/to/wlid/results" --model "path/to/nllb/model"
+```
+
+### MMS-Zeroshot - U-roman Acoustic Model
+Download the model from the [official source](https://huggingface.co/spaces/mms-meta/mms-zeroshot/tree/main).
+
+First run u-romanization on the N-best ASR hypotheses:
+```
+python mms-zs/uromanize.py --txt "path/to/asr/results"/nbest_asr_hyp --lid "path/to/lid/results"/nbest_lid --dst "path/to/uasr/results" --model "path/to/mms-zeroshot"
+```
+
+Then compute the forced alignment score using the MMS-Zeroshot model:
+```
+python mms-zs/falign.py --uroman_txt "path/to/uasr/results"/nbest_asr_hyp_uroman --wav "path/to/wav/list" --dst "path/to/uasr/results" --model "path/to/mms-zeroshot"
+```
+
+## 3) Commands to Tune Feature Coefficients
+```
+python rerank/tune_coefficients.py --slid "path/to/lid/results"/slid_score --asr "path/to/asr/results"/asr_score --wlid "path/to/wlid/results"/wlid_score --lm "path/to/lm/results"/lm_score --uasr "path/to/uasr/results"/uasr_score --dst "path/to/rerank/results" --ref_lid "ground-truth/lid" --nbest_lid "path/to/lid/results"/nbest_lid --ref_asr "ground-truth/asr" --nbest_asr "path/to/asr/results"/nbest_asr_hyp
+```
+
+## 4) Commands to Apply on Test Set
+```
+python rerank/rerank.py --slid "path/to/lid/results"/slid_score --asr "path/to/asr/results"/asr_score --wlid "path/to/wlid/results"/wlid_score --lm "path/to/lm/results"/lm_score --uasr "path/to/uasr/results"/uasr_score --dst "path/to/rerank/results" --ref_lid "ground-truth/lid" --nbest_lid "path/to/lid/results"/nbest_lid --ref_asr "ground-truth/asr" --nbest_asr "path/to/asr/results"/nbest_asr_hyp --w "path/to/rerank/results"/best_coefficients
+```
+
+The re-ranked LID and ASR hypotheses will be in `"path/to/rerank/results"/reranked_1best_lid` and `"path/to/rerank/results"/reranked_1best_asr_hyp`, respectively.
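+
+### How Candidates Are Scored
+For reference, the re-ranking rule implemented in `rerank/rerank.py` reduces to a weighted sum: each of the N candidates per utterance gets a score that combines its six features (SLID, WLID, ASR, LM, UASR, and hypothesis length) with the tuned coefficients, and the highest-scoring candidate is selected. Below is a minimal Python sketch of this rule; the function name and argument layout are illustrative, not part of the scripts.
+```python
+# Minimal sketch of the re-ranking rule in rerank/rerank.py.
+# feats: one row per candidate, in n-best order, holding the six feature
+# values [slid, wlid, asr, lm, uasr, length]; w: tuned coefficient vector.
+def pick_1best(feats, w, nbest_hyps, n=10):
+    # score every candidate as a weighted sum of its features
+    scores = [sum(wi * fi for wi, fi in zip(w, f)) for f in feats]
+    best = []
+    for i in range(len(scores) // n):
+        # candidates for utterance i occupy positions i*n .. i*n + n - 1
+        j = max(range(n), key=lambda k: scores[i * n + k])
+        best.append(nbest_hyps[i * n + j])
+    return best
+```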
+
+# Citation
+```
+@article{yan2024wild,
+    title={Improving Multilingual ASR in the Wild Using Simple N-best Re-ranking},
+    author={Yan, Brian and Pratap, Vineel and Watanabe, Shinji and Auli, Michael},
+    journal={arXiv},
+    year={2024}
+}
+```
\ No newline at end of file
diff --git a/examples/mms/lid_rerank/cer_langs.txt b/examples/mms/lid_rerank/cer_langs.txt
new file mode 100644
index 0000000000..ca012967cd
--- /dev/null
+++ b/examples/mms/lid_rerank/cer_langs.txt
@@ -0,0 +1,11 @@
+adx
+bod
+cmn
+dzo
+jpn
+khg
+khm
+lao
+mya
+tha
+yue
diff --git a/examples/mms/lid_rerank/mala/infer.py b/examples/mms/lid_rerank/mala/infer.py
new file mode 100755
index 0000000000..60acc164dc
--- /dev/null
+++ b/examples/mms/lid_rerank/mala/infer.py
@@ -0,0 +1,55 @@
+from transformers import AutoTokenizer, AutoModelForCausalLM
+from peft import PeftModel
+from tqdm import tqdm
+import argparse
+import os
+import torch
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--txt", type=str)
+parser.add_argument("--dst", type=str)
+parser.add_argument("--gpu", type=int, default=1)
+args = parser.parse_args()
+
+if __name__ == "__main__":
+    if not os.path.exists(args.dst):
+        os.makedirs(args.dst)
+
+    base_model = AutoModelForCausalLM.from_pretrained('meta-llama/Llama-2-7b-hf')
+    base_model.resize_token_embeddings(260164)
+    tokenizer = AutoTokenizer.from_pretrained('MaLA-LM/mala-500')
+    if args.gpu == 1:
+        model = PeftModel.from_pretrained(base_model, 'MaLA-LM/mala-500').to("cuda")
+    else:
+        model = PeftModel.from_pretrained(base_model, 'MaLA-LM/mala-500')
+    model.eval()
+
+    txts = [x.strip() for x in open(args.txt, "r").readlines()]
+
+    with open(args.dst + "/lm_score", "w", buffering=1) as f:
+        for t in tqdm(txts):
+            input_tokens = tokenizer("", add_special_tokens=True, return_tensors='pt').input_ids
+            if len(t) > 0:
+                output_tokens = tokenizer(t, add_special_tokens=False, return_tensors='pt').input_ids
+                tokens = torch.cat([input_tokens, output_tokens], dim=1)
+                length = output_tokens.shape[-1]
+            else:
+                tokens = input_tokens
+                length = 0
+
+            if args.gpu == 1:
+                tokens = tokens.to("cuda")
+
+            with torch.no_grad():
+                outputs = model(tokens)
+                logits = outputs.logits
+
+            # sum the log-probability of each token given its left context
+            log_sum = 0
+            for i in range(tokens.shape[-1] - 1):
+                past_tok, current_tok = i, i + 1
+                token_logit = logits[0, past_tok, :]
+                token_log_probs = torch.nn.functional.log_softmax(token_logit, dim=-1)
+                log_token_prob = token_log_probs[tokens[0, current_tok]].item()
+                log_sum += log_token_prob
+
+            f.write(str(log_sum) + "\n")
\ No newline at end of file
diff --git a/examples/mms/lid_rerank/mms-zs/falign.py b/examples/mms/lid_rerank/mms-zs/falign.py
new file mode 100755
index 0000000000..bb170b7e99
--- /dev/null
+++ b/examples/mms/lid_rerank/mms-zs/falign.py
@@ -0,0 +1,87 @@
+import os
+import tempfile
+import re
+import librosa
+import torch
+import json
+import numpy as np
+import argparse
+from tqdm import tqdm
+import math
+
+from transformers import Wav2Vec2ForCTC, AutoProcessor
+
+from lib import falign_ext
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--uroman_txt", type=str)
+parser.add_argument("--wav", type=str)
+parser.add_argument("--dst", type=str)
+parser.add_argument("--model", type=str)
+parser.add_argument("--n", type=int, default=10)
+args = parser.parse_args()
+
+ASR_SAMPLING_RATE = 16_000
+
+MODEL_ID = "/upload/mms_zs"
+
+processor = AutoProcessor.from_pretrained(args.model+MODEL_ID)
+model = Wav2Vec2ForCTC.from_pretrained(args.model+MODEL_ID)
+
+token_file = args.model+"/upload/mms_zs/tokens.txt"
+
+if __name__ == "__main__":
"__main__": + if not os.path.exists(args.dst): + os.makedirs(args.dst) + + tokens = [x.strip() for x in open(token_file, "r").readlines()] + + txts = [x.strip() for x in open(args.uroman_txt, "r").readlines()] + wavs = [x.strip() for x in open(args.wav, "r").readlines()] + assert len(txts) == args.n * len(wavs) + + if torch.cuda.is_available(): + device = torch.device("cuda") + elif ( + hasattr(torch.backends, "mps") + and torch.backends.mps.is_available() + and torch.backends.mps.is_built() + ): + device = torch.device("mps") + else: + device = torch.device("cpu") + + model.to(device) + + # clear it + with open(args.dst + "/uasr_score", "w") as f1: + pass + + for i, w in tqdm(enumerate(wavs)): + assert isinstance(w, str) + audio_samples = librosa.load(w, sr=ASR_SAMPLING_RATE, mono=True)[0] + + inputs = processor( + audio_samples, sampling_rate=ASR_SAMPLING_RATE, return_tensors="pt" + ) + inputs = inputs.to(device) + + with torch.no_grad(): + outputs = model(**inputs).logits + + emissions = outputs.log_softmax(dim=-1).squeeze() + + for j in range(args.n): + idx = (args.n * i) + j + chars = txts[idx].split() + token_sequence = [tokens.index(x) for x in chars] + + try: + _, alphas, _ = falign_ext.falign(emissions, torch.tensor(token_sequence, device=device).int(), False) + aligned_alpha = max(alphas[-1]).item() + except: + aligned_alpha = math.log(0.000000001) + + with open(args.dst + "/uasr_score", "a") as f1: + f1.write(str(aligned_alpha) + "\n") + f1.flush() \ No newline at end of file diff --git a/examples/mms/lid_rerank/mms-zs/lib.py b/examples/mms/lid_rerank/mms-zs/lib.py new file mode 100755 index 0000000000..def3b82096 --- /dev/null +++ b/examples/mms/lid_rerank/mms-zs/lib.py @@ -0,0 +1,239 @@ +import os +from dataclasses import dataclass +import torch +import torch.utils.cpp_extension + +cuda_source = """ + +#include +#include +#include +#include +#include +#include +#include + +using namespace torch::indexing; + +constexpr int kNumThreads = 1024; +constexpr float kNegInfinity = -std::numeric_limits::infinity(); +constexpr int kBlankIdx = 0; + +__global__ void +falign_cuda_step_kernel( + const torch::PackedTensorAccessor32 + emissions_a, + const torch::PackedTensorAccessor32 + target_a, + const int T, const int L, const int N, const int R, const int t, int start, + int end, torch::PackedTensorAccessor32 + runningAlpha_a, + torch::PackedTensorAccessor32 + backtrack_a, const bool normalize) +{ + int S = 2 * L + 1; + + int idx1 = (t % 2); // current time step frame for alpha + int idx2 = ((t - 1) % 2); // previous time step frame for alpha + + // reset alpha and backtrack values + for (int i = threadIdx.x; i < S; i += blockDim.x) { + runningAlpha_a[idx1][i] = kNegInfinity; + backtrack_a[i] = -1; + } + // This could potentially be removed through careful indexing inside each thread + // for the above for loop. But this is okay for now. + __syncthreads(); + + if (t == 0) { + for (int i = start + threadIdx.x; i < end; i += blockDim.x) { + int labelIdx = (i % 2 == 0) ? 
+            runningAlpha_a[idx1][i] = emissions_a[0][labelIdx];
+        }
+        return;
+    }
+
+    using BlockReduce = cub::BlockReduce<float, kNumThreads>;
+    __shared__ typename BlockReduce::TempStorage tempStorage;
+    __shared__ float maxValue;
+
+    float threadMax;
+
+    int startloop = start;
+
+    threadMax = kNegInfinity;
+
+    if (start == 0 && threadIdx.x == 0) {
+        runningAlpha_a[idx1][0] =
+            runningAlpha_a[idx2][0] + emissions_a[t][kBlankIdx];
+        threadMax = max(threadMax, runningAlpha_a[idx1][0]);
+
+        backtrack_a[0] = 0;
+        // startloop += 1; // startloop is threadlocal meaning it would only be changed for threads entering this loop (ie threadIdx == 0)
+    }
+    if(start == 0) {
+        startloop += 1;
+    }
+
+    for (int i = startloop + threadIdx.x; i < end; i += blockDim.x) {
+        float x0 = runningAlpha_a[idx2][i];
+        float x1 = runningAlpha_a[idx2][i - 1];
+        float x2 = kNegInfinity;
+
+        int labelIdx = (i % 2 == 0) ? kBlankIdx : target_a[i / 2];
+
+        if (i % 2 != 0 && i != 1 && target_a[i / 2] != target_a[i / 2 - 1]) {
+            x2 = runningAlpha_a[idx2][i - 2];
+        }
+
+        float result = 0.0;
+        if (x2 > x1 && x2 > x0) {
+            result = x2;
+            backtrack_a[i] = 2;
+        } else if (x1 > x0 && x1 > x2) {
+            result = x1;
+            backtrack_a[i] = 1;
+        } else {
+            result = x0;
+            backtrack_a[i] = 0;
+        }
+
+        runningAlpha_a[idx1][i] = result + emissions_a[t][labelIdx];
+        threadMax = max(threadMax, runningAlpha_a[idx1][i]);
+    }
+
+    float maxResult = BlockReduce(tempStorage).Reduce(threadMax, cub::Max());
+    if (threadIdx.x == 0) {
+        maxValue = maxResult;
+    }
+
+    __syncthreads();
+    // normalize alpha values so that they don't overflow for large T
+    if(normalize) {
+        for (int i = threadIdx.x; i < S; i += blockDim.x) {
+            runningAlpha_a[idx1][i] -= maxValue;
+        }
+    }
+}
+
+std::tuple<std::vector<int>, torch::Tensor, torch::Tensor>
+falign_cuda(const torch::Tensor& emissions, const torch::Tensor& target, const bool normalize=false)
+{
+    TORCH_CHECK(emissions.is_cuda(), "need cuda tensors");
+    TORCH_CHECK(target.is_cuda(), "need cuda tensors");
+    TORCH_CHECK(target.device() == emissions.device(),
+                "need tensors on same cuda device");
+    TORCH_CHECK(emissions.dim() == 2 && target.dim() == 1, "invalid sizes");
+    TORCH_CHECK(target.sizes()[0] > 0, "target size cannot be empty");
+
+    int T = emissions.sizes()[0]; // num frames
+    int N = emissions.sizes()[1]; // alphabet size
+    int L = target.sizes()[0]; // label length
+    const int S = 2 * L + 1;
+
+    auto targetCpu = target.to(torch::kCPU);
+
+    // backtrack stores the index offset of the best path at the current position
+    // We copy the values to CPU after running every time frame.
+    auto backtrack = torch::zeros({ S }, torch::kInt32).to(emissions.device());
+    auto backtrackCpu = torch::zeros(
+        { T, S }, torch::TensorOptions().dtype(torch::kInt32).device(torch::kCPU));
+    TORCH_CHECK(backtrack.is_cuda(), "need cuda tensors");
+    TORCH_CHECK(!backtrackCpu.is_cuda(), "need cpu tensors");
+
+    // we store only two time frames for alphas;
+    // alphas for the current time frame can be computed from the previous one alone.
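+    // (Note: this rolling two-row buffer keeps the on-device forward pass at
+    // O(S) memory instead of O(T * S); the full T x S history needed for
+    // backtracking is accumulated on the CPU side in backtrackCpu, with
+    // alphaCpu kept alongside for debugging.)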
+    auto runningAlpha =
+        torch::zeros(
+            { 2, S },
+            torch::TensorOptions().dtype(torch::kFloat).device(emissions.device()));
+    auto alphaCpu =
+        torch::zeros(
+            { T, S },
+            torch::TensorOptions().dtype(torch::kFloat).device(torch::kCPU));
+    TORCH_CHECK(runningAlpha.is_cuda(), "need cuda tensors");
+    TORCH_CHECK(!alphaCpu.is_cuda(), "need cpu tensors");
+
+    auto stream = at::cuda::getCurrentCUDAStream();
+
+    // CUDA accessors
+    auto emissions_a = emissions.packed_accessor32<float, 2, torch::RestrictPtrTraits>();
+    auto target_a = target.packed_accessor32<int, 1, torch::RestrictPtrTraits>();
+    auto runningAlpha_a =
+        runningAlpha.packed_accessor32<float, 2, torch::RestrictPtrTraits>();
+    auto backtrack_a =
+        backtrack.packed_accessor32<int, 1, torch::RestrictPtrTraits>();
+
+    // CPU accessors
+    auto targetCpu_a = targetCpu.accessor<int, 1>();
+    auto backtrackCpu_a = backtrackCpu.accessor<int, 2>();
+    auto alphaCpu_a = alphaCpu.accessor<float, 2>();
+
+    // count the number of repeats in label
+    int R = 0;
+    for (int i = 1; i < L; ++i) {
+        if (targetCpu_a[i] == targetCpu_a[i - 1]) {
+            ++R;
+        }
+    }
+    TORCH_CHECK(T >= (L + R), "invalid sizes 2");
+
+    int start = (T - (L + R)) > 0 ? 0 : 1;
+    int end = (S == 1) ? 1 : 2;
+    for (int t = 0; t < T; ++t) {
+        if (t > 0) {
+            if (T - t <= L + R) {
+                if ((start % 2 == 1) &&
+                    (targetCpu_a[start / 2] != targetCpu_a[start / 2 + 1])) {
+                    start = start + 1;
+                }
+                start = start + 1;
+            }
+            if (t <= L + R) {
+                if ((end % 2 == 0) && (end < 2 * L) &&
+                    (targetCpu_a[end / 2 - 1] != targetCpu_a[end / 2])) {
+                    end = end + 1;
+                }
+                end = end + 1;
+            }
+        }
+        falign_cuda_step_kernel<<<1, kNumThreads, 0, stream>>>(
+            emissions_a, target_a, T, L, N, R, t, start, end, runningAlpha_a,
+            backtrack_a, normalize);
+
+        backtrackCpu.index_put_({ t, Slice()}, backtrack.to(torch::kCPU));
+        alphaCpu.index_put_({ t, Slice()}, runningAlpha.slice(0, t % 2, t % 2 + 1).to(torch::kCPU));
+    }
+
+    int idx1 = ((T - 1) % 2);
+    int ltrIdx = runningAlpha[idx1][S - 1].item<float>() >
+            runningAlpha[idx1][S - 2].item<float>()
+        ? S - 1
+        : S - 2;
+
+    std::vector<int> path(T);
+    for (int t = T - 1; t >= 0; --t) {
+        path[t] = (ltrIdx % 2 == 0) ? 0 : targetCpu_a[ltrIdx / 2];
+        ltrIdx -= backtrackCpu_a[t][ltrIdx];
+    }
+
+    // returning runningAlpha, backtrackCpu for debugging purposes
+    return std::make_tuple(path, alphaCpu, backtrackCpu);
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
+{
+    m.def("falign", &falign_cuda, "falign cuda");
+}
+"""
+falign_ext = torch.utils.cpp_extension.load_inline("falign", cpp_sources="", cuda_sources=cuda_source, extra_cflags=['-O3'], verbose=True)
\ No newline at end of file
diff --git a/examples/mms/lid_rerank/mms-zs/uromanize.py b/examples/mms/lid_rerank/mms-zs/uromanize.py
new file mode 100755
index 0000000000..5615901eef
--- /dev/null
+++ b/examples/mms/lid_rerank/mms-zs/uromanize.py
@@ -0,0 +1,69 @@
+import os
+import tempfile
+import re
+import argparse
+from tqdm import tqdm
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--txt", type=str)
+parser.add_argument("--lid", type=str)
+parser.add_argument("--dst", type=str)
+parser.add_argument("--model", type=str)
+args = parser.parse_args()
+
+UROMAN_PL = args.model + "/uroman/bin/uroman.pl"
+
+def norm_uroman(text):
+    text = text.lower()
+    text = text.replace("’", "'")
+    text = re.sub("([^a-z' ])", " ", text)
+    text = re.sub(" +", " ", text)
+    return text.strip()
+
+def uromanize(words):
+    iso = "xxx"
+    with tempfile.NamedTemporaryFile() as tf, tempfile.NamedTemporaryFile() as tf2:
+        with open(tf.name, "w") as f:
+            f.write("\n".join(words))
+        cmd = f"perl " + UROMAN_PL
+        cmd += f" -l {iso} "
+        cmd += f" < {tf.name} > {tf2.name}"
+        os.system(cmd)
+        lexicon = {}
+        with open(tf2.name) as f:
+            for idx, line in enumerate(f):
+                if not line.strip():
+                    continue
+                line = re.sub(r"\s+", "", norm_uroman(line)).strip()
+                lexicon[words[idx]] = " ".join(line) + " |"
+    return lexicon
+
+def convert_sent(txt, char_lang=False):
+    if char_lang:
+        words = txt
+    else:
+        words = txt.split(" ")
+    lexicon = uromanize(words)
+    pron = []
+    pron_no_sp = []
+    for w in words:
+        if w in lexicon:
+            pron.append(lexicon[w])
+            pron_no_sp.append(lexicon[w].replace(" |", ""))
+
+    return " ".join(pron), " ".join(pron_no_sp)
+
+if __name__ == "__main__":
+    if not os.path.exists(args.dst):
+        os.makedirs(args.dst)
+
+    txts = [x.strip() for x in open(args.txt, "r").readlines()]
+    langs = [x.strip() for x in open(args.lid, "r").readlines()]
+    assert len(txts) == len(langs)
+
+    cer_langs = [x.strip() for x in open("cer_langs.txt", "r").readlines()]
+
+    with open(args.dst + "/nbest_asr_hyp_uroman", "w", buffering=1) as f:
+        for t, l in tqdm(zip(txts, langs), total=len(txts)):
+            pron, _ = convert_sent(t, l in cer_langs)
+            f.write(pron + "\n")
diff --git a/examples/mms/lid_rerank/mms/make_parallel_single_runs.py b/examples/mms/lid_rerank/mms/make_parallel_single_runs.py
new file mode 100755
index 0000000000..e015b14f4c
--- /dev/null
+++ b/examples/mms/lid_rerank/mms/make_parallel_single_runs.py
@@ -0,0 +1,22 @@
+import argparse
+import json
+from collections import defaultdict
+import os
+from tqdm import tqdm
+import sys
+import subprocess
+import re
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description='Example argument parser')
+    parser.add_argument('--dump', type=str)
+    parser.add_argument('--model', type=str)
+    parser.add_argument('--dst', type=str)
+    parser.add_argument('--fairseq_dir', type=str)
+    args = parser.parse_args()
+
+    langs = [d for d in os.listdir(args.dump) if os.path.isdir(os.path.join(args.dump, d))]
+
+    for lang in langs:
+        print(f"python mms/run_single_lang.py --dump {os.path.abspath(args.dump)} --lang {lang} --model {os.path.abspath(args.model)} --dst {os.path.abspath(args.dst)} --fairseq_dir {os.path.abspath(args.fairseq_dir)}")
+
\ No newline at end of file
diff --git a/examples/mms/lid_rerank/mms/merge_by_lang.py b/examples/mms/lid_rerank/mms/merge_by_lang.py
new file mode 100755
index 0000000000..9a643b9289
--- /dev/null
+++ b/examples/mms/lid_rerank/mms/merge_by_lang.py
@@ -0,0 +1,33 @@
+import argparse
+import json
+from collections import defaultdict
+import os
+import soundfile as sf
+from tqdm import tqdm
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description='Example argument parser')
+    parser.add_argument('--exp', type=str)
+    parser.add_argument('--dump', type=str)
+    args = parser.parse_args()
+
+    langs = [d for d in os.listdir(args.dump) if os.path.isdir(os.path.join(args.dump, d))]
+
+    data = {}
+
+    for lang in langs:
+        ids = [int(x.strip()) for x in open(args.dump + "/" + lang + "/ids.txt", "r").readlines()]
+        word_hyps = [x.strip() for x in open(args.exp + "/" + lang + "/hypo.word.reord", "r").readlines()]
+        scores = [x.strip() for x in open(args.exp + "/" + lang + "/asr_score.reord", "r").readlines()]
+        assert len(ids) == len(word_hyps)
+        assert len(ids) == len(scores)
+        for id, word_hyp, s in zip(ids, word_hyps, scores):
+            if id in data:
+                raise ValueError(f"Duplicate ID found: {id}")
+            data[id] = (word_hyp, s)
+
+    with open(args.exp + "/nbest_asr_hyp", "w") as f1, open(args.exp + "/asr_score", "w") as f2:
+        for i in range(len(data.keys())):
+            f1.write(data[i][0] + "\n")
+            f2.write(data[i][1] + "\n")
\ No newline at end of file
diff --git a/examples/mms/lid_rerank/mms/prep_wav_list.py b/examples/mms/lid_rerank/mms/prep_wav_list.py
new file mode 100755
index 0000000000..725274fcbe
--- /dev/null
+++ b/examples/mms/lid_rerank/mms/prep_wav_list.py
@@ -0,0 +1,22 @@
+import soundfile as sf
+import argparse
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description='Example argument parser')
+    parser.add_argument('--src', type=str)
+    parser.add_argument('--dst', type=str)
+    args = parser.parse_args()
+
+    wavs = [x.strip() for x in open(args.src, "r").readlines()]
+
+    new_lines = ["/"]
+    for wav in wavs:
+        # Read the wav file
+        data, sample_rate = sf.read(wav)
+
+        # Number of samples is the length of the data array
+        num_samples = len(data)
+
+        new_lines.append(wav+"\t"+str(num_samples))
+
+    with open(args.dst, "w") as f:
+        f.writelines([x+"\n" for x in new_lines])
diff --git a/examples/mms/lid_rerank/mms/run_single_lang.py b/examples/mms/lid_rerank/mms/run_single_lang.py
new file mode 100755
index 0000000000..aeebb17120
--- /dev/null
+++ b/examples/mms/lid_rerank/mms/run_single_lang.py
@@ -0,0 +1,65 @@
+import argparse
+import json
+from collections import defaultdict
+import os
+from tqdm import tqdm
+import sys
+import subprocess
+import re
+
+mapping = {"cmn":"cmn-script_simplified", "srp":"srp-script_latin", "urd":"urd-script_arabic", "uzb":"uzb-script_latin", "yue":"yue-script_traditional", "aze":"azj-script_latin", "kmr":"kmr-script_latin"}
+
+def reorder_decode(hypos):
+    outputs = []
+    for hypo in hypos:
+        idx = int(re.findall(r"\(None-(\d+)\)$", hypo)[0])
+        hypo = re.sub(r"\(\S+\)$", "", hypo).strip()
+        outputs.append((idx, hypo))
+    outputs = sorted(outputs)
+    return outputs
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description='Example argument parser')
+    parser.add_argument('--dump', type=str)
+    parser.add_argument('--model', type=str)
+    parser.add_argument('--fairseq_dir', type=str)
+    parser.add_argument('--dst', type=str)
+    parser.add_argument('--lang', type=str)
+    args = parser.parse_args()
+
+    if not os.path.exists(args.dst):
+        os.makedirs(args.dst)
+    lang = args.lang
+    dst = args.dst + "/" + lang
+    if not os.path.exists(dst):
+        os.makedirs(dst)
+    dump = args.dump + "/" + lang
+    if lang in mapping:
+        lang_code = mapping[lang]
+    else:
+        lang_code = lang
+
+    cmd = f"""
+    cd {args.fairseq_dir}/ &&\
+    PYTHONPATH=. PREFIX=INFER HYDRA_FULL_ERROR=1 python examples/speech_recognition/new/infer.py -m --config-dir examples/mms/asr/config/ --config-name infer_common decoding.type=viterbi dataset.max_tokens=1440000 distributed_training.distributed_world_size=1 "common_eval.path='{args.model}'" task.data={dump} dataset.gen_subset="{lang_code}:test" common_eval.post_process=letter decoding.results_path={dst} &&\
+    cd -
+    """
+
+    print(cmd, file=sys.stderr)
+    print(f">>> {lang}", file=sys.stderr)
+    try:
+        subprocess.run(cmd, shell=True, stdout=subprocess.DEVNULL)
+        with open(dst + "/hypo.word") as fr, open(dst + "/hypo.word.reord", "w") as fw:
+            hypos = fr.readlines()
+            outputs = reorder_decode(hypos)
+            fw.writelines([re.sub(r"\(\S+\)$", "", hypo).strip() + "\n" for ii, hypo in outputs])
+        with open(dst + "/asr_score") as fr, open(dst + "/asr_score.reord", "w") as fw:
+            hypos = fr.readlines()
+            outputs = reorder_decode(hypos)
+            fw.writelines([re.sub(r"\(\S+\)$", "", hypo).strip() + "\n" for ii, hypo in outputs])
+    except Exception:
+        print(f"Something went wrong with {lang}. If {lang} is not supported by the ASR model, then this is expected and OK. If it is supported, then something else has gone wrong unexpectedly.", file=sys.stderr)
+        with open(dst + "/hypo.word.reord", "w") as fw:
+            fw.writelines(["\n"] * len(open(dump+"/ids.txt", "r").readlines()))
+        with open(dst + "/asr_score.reord", "w") as fw:
+            fw.writelines(["\n"] * len(open(dump+"/ids.txt", "r").readlines()))
\ No newline at end of file
diff --git a/examples/mms/lid_rerank/mms/split_by_lang.py b/examples/mms/lid_rerank/mms/split_by_lang.py
new file mode 100755
index 0000000000..b123e406f2
--- /dev/null
+++ b/examples/mms/lid_rerank/mms/split_by_lang.py
@@ -0,0 +1,90 @@
+import argparse
+import json
+from collections import defaultdict
+import os
+import soundfile as sf
+from tqdm import tqdm
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description='Example argument parser')
+    parser.add_argument('--wavs_tsv', type=str)
+    parser.add_argument('--lid_preds', type=str)
+    parser.add_argument('--dst', type=str)
+    parser.add_argument('--refs', type=str, default=None)
+    parser.add_argument('--langs', type=str, default=None)
+    parser.add_argument('--confs', type=str, default=None)
+    args = parser.parse_args()
+
+    # split wavs into dst/lang/wav.txt and dst/lang/ids.txt
+    # uses lid_preds to create topk asr; 1 wav has k different lid
+
+    wavs_tsv = [x for x in open(args.wavs_tsv, "r").readlines()]
+    root = wavs_tsv[0]
+    wavs = wavs_tsv[1:]
+    lid_preds = [eval(x) for x in open(args.lid_preds, "r").readlines()]
+    if args.refs is not None:
+        refs = [x.strip() for x in open(args.refs, "r").readlines()]
+        assert len(wavs) == len(refs)
+        refs_filt = []
+    if args.langs is not None:
+        langs = [x.strip() for x in open(args.langs, "r").readlines()]
+        assert len(wavs) == len(langs)
+        langs_filt = []
+    if args.confs is not None:
+        confs = [x.strip() for x in open(args.confs, "r").readlines()]
+        assert len(wavs) == len(confs)
+        confs_filt = []
+
+    assert len(wavs) == len(lid_preds)
+
+    topk_wavs = []
+    topk_langs = []
+
+    for i, (w, p) in enumerate(zip(wavs, lid_preds)):
+        if p == "n/a":
== "n/a": + continue + + assert len(p) == len(lid_preds[0]) + + for l, _ in p: + topk_wavs.append(w) + topk_langs.append(l) + + if args.refs is not None: + refs_filt.append(refs[i]) + if args.langs is not None: + langs_filt.append(langs[i]) + if args.confs is not None: + confs_filt.append(confs[i]) + + lang_split = defaultdict(list) + for id, (wav,lid) in enumerate(zip(topk_wavs, topk_langs)): + lang_split[lid].append((id, wav)) + + for lang in tqdm(lang_split.keys()): + if not os.path.exists(args.dst + "/" + lang): + os.makedirs(args.dst + "/" + lang) + + with open(args.dst + "/" + lang + "/test.tsv", "w") as f1, \ + open(args.dst + "/" + lang + "/ids.txt", "w") as f2: + f1.write(root) + f1.writelines([x[1] for x in lang_split[lang]]) + f2.writelines([str(x[0]) + "\n" for x in lang_split[lang]]) + + with open(args.dst + "/" + lang + "/test.ltr", "w") as fw: + fw.write("d u m m y | d u m m y |\n"*len(lang_split[lang])) + with open(args.dst + "/" + lang + "/test.wrd", "w") as fw: + fw.write("dummy dummy\n"*len(lang_split[lang])) + + with open(args.dst + "/lid.txt", "w") as f: + f.writelines([x+"\n" for x in topk_langs]) + + if args.refs is not None: + with open(args.dst + "/refs.txt", "w") as f: + f.writelines([x+"\n" for x in refs_filt]) + if args.langs is not None: + with open(args.dst + "/langs.txt", "w") as f: + f.writelines([x+"\n" for x in langs_filt]) + if args.confs is not None: + with open(args.dst + "/confs.txt", "w") as f: + f.writelines([x+"\n" for x in confs_filt]) \ No newline at end of file diff --git a/examples/mms/lid_rerank/nllb/infer.py b/examples/mms/lid_rerank/nllb/infer.py new file mode 100755 index 0000000000..1f4d69a907 --- /dev/null +++ b/examples/mms/lid_rerank/nllb/infer.py @@ -0,0 +1,46 @@ +#!/usr/bin/env python3 +# -*- encoding: utf8 -*- +import fasttext +from tqdm import tqdm +import argparse +import os +import math + +parser = argparse.ArgumentParser() +parser.add_argument("--txt", type=str) +parser.add_argument("--dst", type=str) +parser.add_argument("--model", type=str) +parser.add_argument('--lid', type=str) +args = parser.parse_args() + +mapping = {"arb":"ara", "azj":"aze", "pes":"fas", "fuv":"ful", "lvs":"lav", "khk":"mon", "zsm":"zlm", "gaz":"orm", "pbt":"pus", "uzn":"uzb", "zho":"cmn"} + +def fix_code(x): + code = x.split("_")[-2] + if code in mapping: + code = mapping[code] + return code + +if __name__ == "__main__": + if not os.path.exists(args.dst): + os.makedirs(args.dst) + + pretrained_lang_model = args.model + model = fasttext.load_model(pretrained_lang_model) + + txts = [x.strip() for x in open(args.txt, "r").readlines()] + lids = [x.strip() for x in open(args.lid, "r").readlines()] + assert len(txts) == len(lids) + + with open(args.dst + "/wlid_score", "w") as f: + for t,l in tqdm(zip(txts, lids)): + predictions = model.predict(t, k=218) # max 218 + predictions = [(fix_code(x), y) for x, y in zip(predictions[0], predictions[1])] + + try: + pred_langs = [x[0] for x in predictions] + idx = pred_langs.index(l) + score = math.log(predictions[idx][-1]) + except: + score = -1000 + f.write(str(score) + "\n") \ No newline at end of file diff --git a/examples/mms/lid_rerank/requirements.txt b/examples/mms/lid_rerank/requirements.txt new file mode 100644 index 0000000000..459c37ee20 --- /dev/null +++ b/examples/mms/lid_rerank/requirements.txt @@ -0,0 +1,10 @@ +transformers +peft +protobuf +blobfile +sentencepiece +fasttext +numpy<=1.26.4 +librosa +ninja +editdistance \ No newline at end of file diff --git a/examples/mms/lid_rerank/rerank/rerank.py 
new file mode 100755
index 0000000000..beea3e6a77
--- /dev/null
+++ b/examples/mms/lid_rerank/rerank/rerank.py
@@ -0,0 +1,132 @@
+import argparse
+import json
+from collections import defaultdict
+import os
+from tqdm import tqdm
+import sys
+import subprocess
+import re
+import math
+import numpy as np
+import editdistance
+from multiprocessing import Pool
+from functools import partial
+import random
+
+cer_langs = [x.strip() for x in open("cer_langs.txt", "r").readlines()]
+
+def select(w, feats, ref_lid, nbest_lid, ref_asr, nbest_asr, n=10, exclude=None):
+    assert len(w) == len(feats[0])
+    scores = []
+    for f in feats:
+        s = 0
+        for i in range(len(w)):
+            s += w[i]*f[i]
+        scores.append(s)
+
+    lid_correct = 0
+    lid_total = 0
+    asr_err = 0
+    asr_total = 0
+    text = []
+    lang = []
+
+    for i in range(len(ref_lid)):
+        if exclude is not None:
+            if ref_lid[i] in exclude:
+                continue
+
+        start_idx = i * n
+        end_idx = start_idx + n
+        cand_scores = scores[start_idx:end_idx]
+        max_idx, max_val = max(enumerate(cand_scores), key=lambda x: x[1])
+
+        cand_feats = feats[start_idx:end_idx]
+
+        lang.append(nbest_lid[start_idx:end_idx][max_idx])
+        if ref_lid[i] == nbest_lid[start_idx:end_idx][max_idx]:
+            lid_correct += 1
+        lid_total += 1
+
+        hyp = nbest_asr[start_idx:end_idx][max_idx]
+        text.append(hyp)
+        ref = ref_asr[i]
+        hyp = hyp.lower()
+        ref = ref.lower()
+        # strip punctuation before scoring
+        hyp = hyp.replace(".", "").replace(",", "").replace("?", "").replace("!", "").replace(":", "").replace(")", "").replace("(", "").replace("-", "")
+        ref = ref.replace(".", "").replace(",", "").replace("?", "").replace("!", "").replace(":", "").replace(")", "").replace("(", "").replace("-", "")
+        if ref_lid[i] in cer_langs:
+            hyp = " ".join(hyp)
+            ref = " ".join(ref)
+
+        hyp_words = hyp.split()
+        tgt_words = ref.split()
+        errs = editdistance.eval(hyp_words, tgt_words)
+        asr_err += errs
+        asr_total += len(tgt_words)
+
+    results = {"lid_acc": lid_correct / lid_total, "asr_wer": asr_err / asr_total, "weights": w}
+
+    return results, text, lang
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description='Example argument parser')
+    parser.add_argument('--slid', type=str)
+    parser.add_argument('--wlid', type=str)
+    parser.add_argument('--asr', type=str)
+    parser.add_argument('--lm', type=str)
+    parser.add_argument('--uasr', type=str)
+    parser.add_argument('--n', type=int, default=10)
+    parser.add_argument('--dst', type=str)
+    parser.add_argument('--ref_lid', type=str)
+    parser.add_argument('--nbest_lid', type=str)
+    parser.add_argument('--ref_asr', type=str)
+    parser.add_argument('--nbest_asr', type=str)
+    parser.add_argument('--w', type=str)
+    parser.add_argument('--tag', type=str, default=None)
+    parser.add_argument('--exclude', nargs="*", default=None) # exclude langs
+    args = parser.parse_args()
+
+    slid = [float(x.strip()) for x in open(args.slid, "r").readlines()]
+    wlid = [float(x.strip()) for x in open(args.wlid, "r").readlines()]
+    asr = [float(x.strip()) for x in open(args.asr, "r").readlines()]
+    lm = [float(x.strip()) for x in open(args.lm, "r").readlines()]
+    uasr = [float(x.strip()) for x in open(args.uasr, "r").readlines()]
+
+    assert len(slid) == len(wlid)
+    assert len(wlid) == len(asr)
+    assert len(asr) == len(lm)
+    assert len(lm) == len(uasr)
+
+    ref_lid = [x.strip() for x in open(args.ref_lid, "r").readlines()]
+    nbest_lid = [x.strip() for x in open(args.nbest_lid, "r").readlines()]
+    ref_asr = [x.strip() for x in open(args.ref_asr, "r").readlines()]
open(args.ref_asr, "r").readlines()] + nbest_asr = [x.strip() for x in open(args.nbest_asr, "r").readlines()] + + assert len(ref_lid) * args.n == len(nbest_lid) + assert len(ref_asr) * args.n == len(nbest_asr) + assert len(ref_lid) == len(ref_asr) + + lengths = [len(x) for x in nbest_asr] + + feats = [[s, w, a, l, u, le] for s,w,a,l,u,le in zip(slid, wlid, asr, lm, uasr, lengths)] + + weight = eval(open(args.w, "r").read())['weights'] + + results, text, lang = select(weight, feats, ref_lid, nbest_lid, ref_asr, nbest_asr, n=args.n, exclude=args.exclude) + + if args.tag is not None: + tag_text = "." + args.tag + else: + tag_text = "" + + with open(args.dst + "/reranked_1best_asr_hyp" + tag_text, "w") as f_out: + f_out.writelines([x+"\n" for x in text]) + + with open(args.dst + "/reranked_1best_lid" + tag_text, "w") as f_out: + f_out.writelines([x+"\n" for x in lang]) + + with open(args.dst + "/text.result" + tag_text, "w") as f_out: + for k in results.keys(): + f_out.write(k + "\t" + str(results[k]) + "\n") diff --git a/examples/mms/lid_rerank/rerank/tune_coefficients.py b/examples/mms/lid_rerank/rerank/tune_coefficients.py new file mode 100755 index 0000000000..fc15f650a7 --- /dev/null +++ b/examples/mms/lid_rerank/rerank/tune_coefficients.py @@ -0,0 +1,138 @@ +import argparse +import os +from tqdm import tqdm +import numpy as np +import editdistance +from multiprocessing import Pool +from functools import partial + +cer_langs = [x.strip() for x in open("cer_langs.txt", "r").readlines()] + +def compute(w, feats, ref_lid, nbest_lid, ref_asr, nbest_asr, n=10, exclude=None): + assert len(w) == len(feats[0]) + scores = [] + for f in feats: + s = 0 + for i in range(len(w)): + s += w[i]*f[i] + scores.append(s) + + lid_correct = 0 + lid_total = 0 + asr_err = 0 + asr_total = 0 + + for i in range(len(ref_lid)): + if exclude is not None: + if ref_lid[i] in exclude: + continue + + start_idx = i * n + end_idx = start_idx + n + cand_scores = scores[start_idx:end_idx] + max_idx, max_val = max(enumerate(cand_scores), key=lambda x: x[1]) + + if ref_lid[i] == nbest_lid[start_idx:end_idx][max_idx]: + lid_correct += 1 + lid_total += 1 + + hyp = nbest_asr[start_idx:end_idx][max_idx] + ref = ref_asr[i] + hyp = hyp.lower() + ref = ref.lower() + hyp = hyp.replace(".", "").replace(",", "").replace("?", "").replace("!", "").replace(":", "").replace(")", "").replace("(", "").replace("-", "") + ref = ref.replace(".", "").replace(",", "").replace("?", "").replace("!", "").replace(":", "").replace(")", "").replace("(", "").replace("-", "") + if ref_lid[i] in cer_langs: + hyp = " ".join(hyp) + ref = " ".join(ref) + + hyp_words = hyp.split() + tgt_words = ref.split() + errs = editdistance.eval(hyp_words, tgt_words) + asr_err += errs + asr_total += len(tgt_words) + + return {"lid_acc": lid_correct / lid_total, "asr_wer": asr_err / asr_total, "weights": w} + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='Example argument parser') + parser.add_argument('--slid', type=str) + parser.add_argument('--wlid', type=str) + parser.add_argument('--asr', type=str) + parser.add_argument('--lm', type=str) + parser.add_argument('--uasr', type=str) + parser.add_argument('--n', type=int, default=10) + parser.add_argument('--dst', type=str) + parser.add_argument('--ref_lid', type=str) + parser.add_argument('--nbest_lid', type=str) + parser.add_argument('--ref_asr', type=str) + parser.add_argument('--nbest_asr', type=str) + parser.add_argument('--iters', type=int, default=10000) + 
+    parser.add_argument('--slid_scale', type=int, default=100)
+    parser.add_argument('--wlid_scale', type=int, default=100)
+    parser.add_argument('--asr_scale', type=int, default=10)
+    parser.add_argument('--lm_scale', type=int, default=10)
+    parser.add_argument('--uasr_scale', type=int, default=10)
+    parser.add_argument('--len_scale', type=int, default=1)
+    parser.add_argument('--num_jobs', type=int, default=64)
+    parser.add_argument('--exclude', nargs="*", default=None) # exclude langs
+    args = parser.parse_args()
+
+    slid = [float(x.strip()) for x in open(args.slid, "r").readlines()]
+    wlid = [float(x.strip()) for x in open(args.wlid, "r").readlines()]
+    asr = [float(x.strip()) for x in open(args.asr, "r").readlines()]
+    lm = [float(x.strip()) for x in open(args.lm, "r").readlines()]
+    uasr = [float(x.strip()) for x in open(args.uasr, "r").readlines()]
+
+    assert len(slid) == len(wlid)
+    assert len(wlid) == len(asr)
+    assert len(asr) == len(lm)
+    assert len(lm) == len(uasr)
+
+    ref_lid = [x.strip() for x in open(args.ref_lid, "r").readlines()]
+    nbest_lid = [x.strip() for x in open(args.nbest_lid, "r").readlines()]
+    ref_asr = [x.strip() for x in open(args.ref_asr, "r").readlines()]
+    nbest_asr = [x.strip() for x in open(args.nbest_asr, "r").readlines()]
+
+    assert len(ref_lid) * args.n == len(nbest_lid)
+    assert len(ref_asr) * args.n == len(nbest_asr)
+    assert len(ref_lid) == len(ref_asr)
+
+    lengths = [len(x) for x in nbest_asr]
+
+    feats = [[s, w, a, l, u, le] for s, w, a, l, u, le in zip(slid, wlid, asr, lm, uasr, lengths)]
+
+    weights = []
+    for i in range(args.iters):
+        s_w = np.random.rand() * args.slid_scale
+        w_w = np.random.rand() * args.wlid_scale
+        a_w = np.random.rand() * args.asr_scale
+        l_w = np.random.rand() * args.lm_scale
+        u_w = np.random.rand() * args.uasr_scale
+        le_w = (np.random.rand() - 0.5) * args.len_scale
+        weights.append([s_w, w_w, a_w, l_w, u_w, le_w])
+
+    num_tries = len(weights)
+    print("Total number of search points", num_tries)
+    threads = args.num_jobs
+    pool = Pool(threads)
+    compute_fxn = partial(compute, feats=feats, ref_lid=ref_lid, nbest_lid=nbest_lid, ref_asr=ref_asr, nbest_asr=nbest_asr, n=args.n, exclude=args.exclude)
+    results = pool.map(compute_fxn, weights)
+    pool.close()
+    pool.join()
+
+    assert len(results) == len(weights)
+
+    wer_best = 100
+    best = ""
+    if not os.path.exists(args.dst):
+        os.makedirs(args.dst)
+    with open(args.dst + "/results.all", "w") as f_out:
+        for result in results:
+            f_out.write(str(result)+"\n")
+            if result["asr_wer"] < wer_best:
+                wer_best = result["asr_wer"]
+                best = result
+
+    with open(args.dst + "/best_coefficients", "w") as f_out:
+        f_out.write(str(best)+"\n")
\ No newline at end of file
diff --git a/examples/mms/lid_rerank/whisper/infer_asr.py b/examples/mms/lid_rerank/whisper/infer_asr.py
new file mode 100755
index 0000000000..0df16d0e45
--- /dev/null
+++ b/examples/mms/lid_rerank/whisper/infer_asr.py
@@ -0,0 +1,71 @@
+#!/usr/bin/env python3
+# -*- encoding: utf8 -*-
+import argparse
+import itertools
+import os
+import re
+import sys
+from pathlib import Path
+
+import whisper
+from tqdm import tqdm
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--wavs", type=str)
+parser.add_argument("--lids", type=str)
+parser.add_argument("--dst", type=str)
+parser.add_argument("--beam_size", type=int, default=1)
+parser.add_argument("--model", type=str)
+parser.add_argument("--mapping", type=str, default="whisper/lid_mapping.txt")
+parser.add_argument("--n", type=int, default=10)
+
+args = parser.parse_args()
+
+if __name__ == "__main__":
+    model = whisper.load_model(args.model)
+
+    print(args)
+
+    # repeat each wav n times so it lines up with its n LID candidates
+    wavs = [y for y in [x.strip() for x in open(args.wavs, "r").readlines()] for _ in range(args.n)]
+    lids = [x.strip() for x in open(args.lids, "r").readlines()]
+    assert len(wavs) == len(lids)
+
+    if args.mapping is not None:
+        # mms_lid_code:whisper_lid_code
+        mapping = {x[1]:x[0] for x in [l.strip().split(";", 1) for l in open(args.mapping, "r").readlines()]}
+    else:
+        mapping = None
+
+    if not os.path.exists(args.dst):
+        os.makedirs(args.dst)
+
+    # clear it
+    with open(args.dst + "/nbest_asr_hyp", "w") as f1, open(args.dst + "/asr_score", "w") as f2:
+        pass
+
+    for wav, lang in tqdm(zip(wavs, lids), total=len(wavs)):
+        # load audio and pad/trim it to fit 30 seconds
+        audio = whisper.load_audio(wav)
+        audio = whisper.pad_or_trim(audio)
+
+        # make log-Mel spectrogram and move to the same device as the model
+        mel = whisper.log_mel_spectrogram(audio).to(model.device)
+
+        if mapping is not None and lang in mapping.keys():
+            lang_code = mapping[lang]
+        else:
+            lang_code = lang
+
+        # decode the audio
+        options = whisper.DecodingOptions(beam_size=args.beam_size, language=lang_code)
+        output = whisper.decode(model, mel, options)
+        result = output.text
+        length = len(output.tokens)
+        score = output.avg_logprob * length
+
+        with open(args.dst + "/nbest_asr_hyp", "a") as f1, open(args.dst + "/asr_score", "a") as f2:
+            f1.write(result + "\n")
+            f2.write(str(score) + "\n")
+            f1.flush()
+            f2.flush()
\ No newline at end of file
diff --git a/examples/mms/lid_rerank/whisper/infer_lid.py b/examples/mms/lid_rerank/whisper/infer_lid.py
new file mode 100755
index 0000000000..150e0bbcca
--- /dev/null
+++ b/examples/mms/lid_rerank/whisper/infer_lid.py
@@ -0,0 +1,65 @@
+#!/usr/bin/env python3
+# -*- encoding: utf8 -*-
+import argparse
+import itertools
+import os
+import re
+import sys
+from pathlib import Path
+import math
+
+import whisper
+from tqdm import tqdm
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--wavs", type=str)
+parser.add_argument("--dst", type=str)
+parser.add_argument("--model", type=str)
+parser.add_argument("--n", type=int, default=10)
+parser.add_argument("--mapping", type=str, default="whisper/lid_mapping.txt")
+args = parser.parse_args()
+
+if __name__ == "__main__":
+    model = whisper.load_model(args.model)
+
+    print(args)
+
+    wavs = [x.strip() for x in open(args.wavs, "r").readlines()]
+    if not os.path.exists(args.dst):
+        os.makedirs(args.dst)
+
+    if args.mapping is not None:
+        # whisper_lid_code:mms_lid_code
+        mapping = {x[0]:x[1] for x in [l.strip().split(";", 1) for l in open(args.mapping, "r").readlines()]}
+    else:
+        mapping = None
+
+    with open(args.dst + "/predictions", "w") as f:
+        for wav in tqdm(wavs):
+            # load audio and pad/trim it to fit 30 seconds
+            audio = whisper.load_audio(wav)
+            audio = whisper.pad_or_trim(audio)
+
+            # make log-Mel spectrogram and move to the same device as the model
+            mel = whisper.log_mel_spectrogram(audio).to(model.device)
+
+            _, probs = model.detect_language(mel)
+            result = sorted(probs.items(), key=lambda x: x[1], reverse=True)[:args.n]
+            f.write(str(result) + "\n")
+
+    lid_preds = [eval(x) for x in open(args.dst + "/predictions", "r").readlines()]
+    lids = []
+    scores = []
+    for p in lid_preds:
+        assert len(p) == len(lid_preds[0])
+        for l, s in p:
+            if args.mapping is not None:
+                lids.append(mapping[l])
+            else:
+                lids.append(l)
+            scores.append(math.log(s))
+    with open(args.dst + "/nbest_lid", "w") as f:
+        f.writelines([x+"\n" for x in lids])
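+    # slid_score holds one log LID probability per n-best entry, line-aligned with nbest_lid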
+ with open(args.dst + "/slid_score", "w") as f: + f.writelines([str(x)+"\n" for x in scores]) \ No newline at end of file diff --git a/examples/mms/lid_rerank/whisper/lid_mapping.txt b/examples/mms/lid_rerank/whisper/lid_mapping.txt new file mode 100755 index 0000000000..ea676fece2 --- /dev/null +++ b/examples/mms/lid_rerank/whisper/lid_mapping.txt @@ -0,0 +1,99 @@ +en;eng +zh;cmn +de;deu +es;spa +ru;rus +ko;kor +fr;fra +ja;jpn +pt;por +tr;tuk +pl;pol +ca;cat +nl;nld +ar;ara +sv;swe +it;ita +id;ind +hi;hin +fi;fin +vi;vie +he;heb +uk;ukr +el;ell +ms;zlm +cs;cez +ro;ron +da;dan +hu;hun +ta;tam +no;nob +th;tha +ur;urd +hr;hrv +bg;bul +lt;lit +la;lat +mi;mri +ml;mal +cy;cym +sk;slk +te;tel +fa;fas +lv;lav +bn;ben +sr;srp +az;aze +sl;slv +kn;kan +et;est +mk;mkd +br;bre +eu;eus +is;isl +hy;hye +ne;npi +mn;mon +bs;bos +kk;kaz +sq;sqi +sw;swh +gl;glg +mr;mar +pa;pan +si;sin +km;khm +sn;sna +yo;yor +so;som +af;afr +oc;oci +ka;kat +be;bel +tg;tgk +sd;snd +gu;guj +am;amh +yi;yid +lo;lao +uz;uzb +fo;fao +ht;hat +ps;pus +tk;tuk +nn;nno +mt;mlk +sa;san +lb;ltz +my;mya +bo;bod +tl;tgl +mg;mlg +as;asm +tt;tat +haw;haw +ln;lin +ha;hau +ba;bak +jw;jav +su;sun diff --git a/examples/speech_recognition/new/decoders/viterbi_decoder.py b/examples/speech_recognition/new/decoders/viterbi_decoder.py index b1c47868fa..a35d95e146 100644 --- a/examples/speech_recognition/new/decoders/viterbi_decoder.py +++ b/examples/speech_recognition/new/decoders/viterbi_decoder.py @@ -18,7 +18,7 @@ def decode( emissions: torch.FloatTensor, ) -> List[List[Dict[str, torch.LongTensor]]]: def get_pred(e): + score = e.log_softmax(dim=-1).max(dim=-1)[0].sum() toks = e.argmax(dim=-1).unique_consecutive() - return toks[toks != self.blank] - - return [[{"tokens": get_pred(x), "score": 0}] for x in emissions] + return {"tokens":toks[toks != self.blank], "score":score} + return [[get_pred(x)] for x in emissions] diff --git a/examples/speech_recognition/new/infer.py b/examples/speech_recognition/new/infer.py index 33ed526907..ca5cea4a7c 100644 --- a/examples/speech_recognition/new/infer.py +++ b/examples/speech_recognition/new/infer.py @@ -144,6 +144,7 @@ def __init__(self, cfg: InferConfig) -> None: self.hypo_units_file = None self.ref_words_file = None self.ref_units_file = None + self.score_file = None self.progress_bar = self.build_progress_bar() @@ -153,6 +154,7 @@ def __enter__(self) -> "InferenceProcessor": self.hypo_units_file = self.get_res_file("hypo.units") self.ref_words_file = self.get_res_file("ref.word") self.ref_units_file = self.get_res_file("ref.units") + self.score_file = self.get_res_file("asr_score") return self def __exit__(self, *exc) -> bool: @@ -161,6 +163,7 @@ def __exit__(self, *exc) -> bool: self.hypo_units_file.close() self.ref_words_file.close() self.ref_units_file.close() + self.score_file.close() return False def __iter__(self) -> Any: @@ -290,7 +293,6 @@ def process_sentence( batch_id: int, ) -> Tuple[int, int]: speaker = None # Speaker can't be parsed from dataset. - if "target_label" in sample: toks = sample["target_label"] else: @@ -314,6 +316,7 @@ def process_sentence( print(f"{hyp_words} ({speaker}-{sid})", file=self.hypo_words_file) print(f"{tgt_pieces} ({speaker}-{sid})", file=self.ref_units_file) print(f"{tgt_words} ({speaker}-{sid})", file=self.ref_words_file) + print(f"{hypo['score'].item()} ({speaker}-{sid})", file=self.score_file) if not self.cfg.common_eval.quiet: logger.info(f"HYPO: {hyp_words}")