From 6a0e41b539c77d6967b3b717e27c2d4c854214cd Mon Sep 17 00:00:00 2001
From: pkufool
Date: Wed, 9 Oct 2024 11:16:38 +0800
Subject: [PATCH] Minor fixes to shallow fusion

---
 egs/librispeech/ASR/zipformer/ctc_decode.py | 126 +++++++++++++++++---
 icefall/decode.py                           |  51 ++++----
 2 files changed, 134 insertions(+), 43 deletions(-)

diff --git a/egs/librispeech/ASR/zipformer/ctc_decode.py b/egs/librispeech/ASR/zipformer/ctc_decode.py
index 8f3dd10d27..183d42360b 100755
--- a/egs/librispeech/ASR/zipformer/ctc_decode.py
+++ b/egs/librispeech/ASR/zipformer/ctc_decode.py
@@ -111,6 +111,7 @@
 import argparse
 import logging
 import math
+import os
 from collections import defaultdict
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple
@@ -288,7 +289,7 @@ def get_parser():
     )
 
     parser.add_argument(
-        "--lm-type",
+        "--nnlm-type",
         type=str,
         default="rnn",
         help="Type of NN lm",
@@ -296,10 +297,10 @@
     )
 
     parser.add_argument(
-        "--lm-scale",
+        "--nnlm-scale",
         type=float,
-        default=0.3,
-        help="""The scale of the neural network LM
+        default=0,
+        help="""The scale of the neural network LM; 0 means don't use NNLM shallow fusion.
         Used only when `--use-shallow-fusion` is set to True.
         """,
     )
@@ -321,6 +322,47 @@
         """,
     )
 
+    parser.add_argument(
+        "--backoff-id",
+        type=int,
+        default=500,
+        help="ID of the backoff symbol in the ngram LM",
+    )
+
+    parser.add_argument(
+        "--lodr-ngram",
+        type=str,
+        help="The path to the LODR n-gram LM",
+    )
+
+    parser.add_argument(
+        "--lodr-lm-scale",
+        type=float,
+        default=0,
+        help="The scale of the LODR n-gram LM; it should be less than 0. 0 means don't use LODR.",
+    )
+
+    parser.add_argument(
+        "--context-score",
+        type=float,
+        default=0,
+        help="""
+        The bonus score for each token of the context biasing words/phrases.
+        0 means don't use contextual biasing.
+        Used only when --decoding-method is ctc-prefix-beam-search-shallow-fussion.
+        """,
+    )
+
+    parser.add_argument(
+        "--context-file",
+        type=str,
+        default="",
+        help="""
+        The path of the context biasing list, one word/phrase per line.
+        Used only when --decoding-method is ctc-prefix-beam-search-shallow-fussion.
+        """,
+    )
+
     parser.add_argument(
         "--skip-scoring",
         type=str2bool,
@@ -358,7 +400,9 @@ def decode_one_batch(
     batch: dict,
     word_table: k2.SymbolTable,
     G: Optional[k2.Fsa] = None,
-    LM: Optional[LmScorer] = None,
+    NNLM: Optional[LmScorer] = None,
+    LODR_lm: Optional[NgramLm] = None,
+    context_graph: Optional[ContextGraph] = None,
 ) -> Dict[str, List[List[str]]]:
     """Decode one batch and return the result in a dict. The dict has the
     following format:
@@ -466,7 +510,10 @@ def decode_one_batch(
         token_ids = ctc_prefix_beam_search_shallow_fussion(
             ctc_output=ctc_output,
             encoder_out_lens=encoder_out_lens,
-            LM=LM,
+            NNLM=NNLM,
+            LODR_lm=LODR_lm,
+            LODR_lm_scale=params.lodr_lm_scale,
+            context_graph=context_graph,
         )
         # hyps is a list of str, e.g., ['xxx yyy zzz', ...]
         hyps = bpe_model.decode(token_ids)
@@ -649,7 +696,9 @@ def decode_dataset(
     bpe_model: Optional[spm.SentencePieceProcessor],
     word_table: k2.SymbolTable,
     G: Optional[k2.Fsa] = None,
-    LM: Optional[LmScorer] = None,
+    NNLM: Optional[LmScorer] = None,
+    LODR_lm: Optional[NgramLm] = None,
+    context_graph: Optional[ContextGraph] = None,
 ) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
     """Decode dataset.
@@ -700,7 +749,9 @@ def decode_dataset(
             batch=batch,
             word_table=word_table,
             G=G,
-            LM=LM,
+            NNLM=NNLM,
+            LODR_lm=LODR_lm,
+            context_graph=context_graph,
         )
 
         for name, hyps in hyps_dict.items():
@@ -835,7 +886,12 @@ def main():
     if "prefix-beam-search" in params.decoding_method:
         params.suffix += f"_beam-{params.beam}"
         if params.decoding_method == "ctc-prefix-beam-search-shallow-fussion":
-            params.suffix += f"_lm-scale-{params.lm_scale}"
+            if params.nnlm_scale != 0:
+                params.suffix += f"_nnlm-scale-{params.nnlm_scale}"
+            if params.lodr_lm_scale != 0:
+                params.suffix += f"_lodr-scale-{params.lodr_lm_scale}"
+            if params.context_score != 0:
+                params.suffix += f"_context-score-{params.context_score}"
 
     if params.use_averaged_model:
         params.suffix += "_use-averaged-model"
@@ -947,17 +1003,49 @@ def main():
         G = None
 
     # only load the neural network LM if required
-    if params.decoding_method == "ctc-prefix-beam-search-shallow-fussion":
-        LM = LmScorer(
-            lm_type=params.lm_type,
+    NNLM = None
+    if (
+        params.decoding_method == "ctc-prefix-beam-search-shallow-fussion"
+        and params.nnlm_scale != 0
+    ):
+        NNLM = LmScorer(
+            lm_type=params.nnlm_type,
             params=params,
             device=device,
-            lm_scale=params.lm_scale,
+            lm_scale=params.nnlm_scale,
         )
-        LM.to(device)
-        LM.eval()
-    else:
-        LM = None
+        NNLM.to(device)
+        NNLM.eval()
+
+    LODR_lm = None
+    if (
+        params.decoding_method == "ctc-prefix-beam-search-shallow-fussion"
+        and params.lodr_lm_scale != 0
+    ):
+        assert os.path.exists(
+            params.lodr_ngram
+        ), f"LODR ngram does not exist, given path: {params.lodr_ngram}"
+        logging.info(f"Loading LODR (token-level LM): {params.lodr_ngram}")
+        LODR_lm = NgramLm(
+            params.lodr_ngram,
+            backoff_id=params.backoff_id,
+            is_binary=False,
+        )
+        logging.info(f"num states: {LODR_lm.lm.num_states}")
+
+    context_graph = None
+    if (
+        params.decoding_method == "ctc-prefix-beam-search-shallow-fussion"
+        and params.context_score != 0
+    ):
+        assert os.path.exists(
+            params.context_file
+        ), f"context_file does not exist, given path: {params.context_file}"
+        contexts = []
+        for line in open(params.context_file).readlines():
+            contexts.append(bpe_model.encode(line.strip()))
+        context_graph = ContextGraph(params.context_score)
+        context_graph.build(contexts)
 
     logging.info("About to create model")
     model = get_model(params)
@@ -1068,7 +1156,9 @@ def main():
             bpe_model=bpe_model,
             word_table=lexicon.word_table,
             G=G,
-            LM=LM,
+            NNLM=NNLM,
+            LODR_lm=LODR_lm,
+            context_graph=context_graph,
         )
 
         save_asr_output(
diff --git a/icefall/decode.py b/icefall/decode.py
index 777f9e3e84..6b642c94d8 100644
--- a/icefall/decode.py
+++ b/icefall/decode.py
@@ -1736,7 +1736,7 @@ def _step_worker(
     B: HypothesisList,
     beam: int = 4,
     blank_id: int = 0,
-    lm_scale: float = 0,
+    nnlm_scale: float = 0,
     LODR_lm_scale: float = 0,
     context_graph: Optional[ContextGraph] = None,
 ) -> HypothesisList:
@@ -1815,14 +1815,16 @@ def _step_worker(
             if update_prefix:
                 lm_score = hyp.lm_score
                 if hyp.lm_log_probs is not None:
-                    lm_score += hyp.lm_log_probs[new_token] * lm_scale
+                    lm_score = lm_score + hyp.lm_log_probs[new_token] * nnlm_scale
                     new_hyp.lm_log_probs = None
                 if context_graph is not None and hyp.context_state is not None:
-                    context_score, new_context_state = context_graph.forward_one_step(
-                        hyp.context_state, new_token
-                    )
-                    lm_score += context_score
+                    (
+                        context_score,
+                        new_context_state,
+                        matched_state,
+                    ) = context_graph.forward_one_step(hyp.context_state, new_token)
+                    lm_score = lm_score + context_score
                     new_hyp.context_state = new_context_state
                 if hyp.LODR_state is not None:
@@ -1833,7 +1835,7 @@
                         state_cost.lm_score,
                         hyp.LODR_state.lm_score,
                     )
-                    lm_score += LODR_lm_scale * current_ngram_score
+                    lm_score = lm_score + LODR_lm_scale * current_ngram_score
                     new_hyp.LODR_state = state_cost
 
             new_hyp.lm_score = lm_score
@@ -1944,7 +1946,7 @@ def ctc_prefix_beam_search_shallow_fussion(
     blank_id: int = 0,
     LODR_lm: Optional[NgramLm] = None,
     LODR_lm_scale: Optional[float] = 0,
-    LM: Optional[LmScorer] = None,
+    NNLM: Optional[LmScorer] = None,
     context_graph: Optional[ContextGraph] = None,
 ) -> List[List[int]]:
     """Implement prefix search decoding in "Connectionist Temporal Classification:
@@ -1981,17 +1983,16 @@
     encoder_out_lens = encoder_out_lens.tolist()
     device = ctc_output.device
-    lm_scale = 0
+    nnlm_scale = 0
     init_scores = None
     init_states = None
-
-    if LM is not None:
-        lm_scale = LM.lm_scale
-        sos_id = getattr(LM, "sos_id", 1)
+    if NNLM is not None:
+        nnlm_scale = NNLM.lm_scale
+        sos_id = getattr(NNLM, "sos_id", 1)
         # get initial lm score and lm state by scoring the "sos" token
         sos_token = torch.tensor([[sos_id]]).to(torch.int64).to(device)
         lens = torch.tensor([1]).to(device)
-        init_scores, init_states = LM.score_token(sos_token, lens)
+        init_scores, init_states = NNLM.score_token(sos_token, lens)
         init_scores, init_states = init_scores.cpu(), (
             init_states[0].cpu(),
             init_states[1].cpu(),
         )
@@ -2016,16 +2017,16 @@
             if j < encoder_out_lens[i]:
                 log_probs, indexes = topk_values[i][j], topk_indexes[i][j]
                 B[i] = _step_worker(
-                    log_probs,
-                    indexes,
-                    B[i],
-                    beam,
-                    blank_id,
-                    lm_scale=lm_scale,
+                    log_probs=log_probs,
+                    indexes=indexes,
+                    B=B[i],
+                    beam=beam,
+                    blank_id=blank_id,
+                    nnlm_scale=nnlm_scale,
                     LODR_lm_scale=LODR_lm_scale,
                     context_graph=context_graph,
                 )
-        if LM is None:
+        if NNLM is None:
             continue
         # update lm_log_probs
         token_list = []  # a list of list
         hs = []
         cs = []
         indexes = []  # (batch_idx, key)
@@ -2035,7 +2036,7 @@
         for batch_idx, hyps in enumerate(B):
             for hyp in hyps:
                 if hyp.lm_log_probs is None:  # those hyps that prefix changes
-                    if LM.lm_type == "rnn":
+                    if NNLM.lm_type == "rnn":
                         token_list.append([hyp.ys[-1]])
                         # store the LSTM states
                         hs.append(hyp.state[0])
                         cs.append(hyp.state[1])
                     else:
                         token_list.append(hyp.ys)
                     indexes.append((batch_idx, hyp.key))
@@ -2046,7 +2047,7 @@
         if len(token_list) != 0:
            x_lens = torch.tensor([len(tokens) for tokens in token_list]).to(device)
-            if LM.lm_type == "rnn":
+            if NNLM.lm_type == "rnn":
                 tokens_to_score = (
                     torch.tensor(token_list).to(torch.int64).to(device).reshape(-1, 1)
                 )
                 hs = torch.cat(hs, dim=1).to(device)
                 cs = torch.cat(cs, dim=1).to(device)
                 state = (hs, cs)
             else:
@@ -2065,13 +2066,13 @@
                 )
                 state = None
 
-            scores, lm_states = LM.score_token(tokens_to_score, x_lens, state)
+            scores, lm_states = NNLM.score_token(tokens_to_score, x_lens, state)
             scores, lm_states = scores.cpu(), (lm_states[0].cpu(), lm_states[1].cpu())
             assert scores.size(0) == len(indexes), (scores.size(0), len(indexes))
             for i in range(scores.size(0)):
                 batch_idx, key = indexes[i]
                 B[batch_idx][key].lm_log_probs = scores[i]
-                if LM.lm_type == "rnn":
+                if NNLM.lm_type == "rnn":
                     state = (
                         lm_states[0][:, i, :].unsqueeze(1),
                         lm_states[1][:, i, :].unsqueeze(1),
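
Note on the scoring change: per emitted token, `_step_worker` now accumulates
three independent shallow-fusion terms on top of the CTC score: the neural LM
log-probability scaled by `nnlm_scale`, a context-biasing bonus from
`ContextGraph.forward_one_step`, and the LODR token-level n-gram score delta
scaled by `LODR_lm_scale` (negative, since LODR subtracts the source-domain
LM). Below is a minimal sketch of just this arithmetic, not part of the
patch; the inputs are hypothetical stand-ins for the NNLM, NgramLm, and
ContextGraph outputs:

def fused_lm_score(
    lm_score: float,
    nnlm_log_prob: float,  # log p(token) from the neural LM (NNLM)
    nnlm_scale: float,  # --nnlm-scale; 0 disables the NNLM term
    context_bonus: float,  # bonus from ContextGraph; 0 when nothing matches
    ngram_score_delta: float,  # LODR n-gram score of the latest token
    lodr_lm_scale: float,  # --lodr-lm-scale; should be < 0, 0 disables LODR
) -> float:
    """Mirrors the per-token lm_score update in _step_worker."""
    lm_score = lm_score + nnlm_log_prob * nnlm_scale
    lm_score = lm_score + context_bonus
    lm_score = lm_score + lodr_lm_scale * ngram_score_delta
    return lm_score

# e.g. NNLM log-prob -1.2 at scale 0.3, a matched biasing token worth +1.5,
# and an LODR score delta of -2.0 at scale -0.1:
print(fused_lm_score(0.0, -1.2, 0.3, 1.5, -2.0, -0.1))  # -> 1.34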
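For reference, the biasing-graph wiring that main() performs for
--context-file / --context-score can be exercised on its own roughly as
follows. This is a sketch under the assumption that icefall's ContextGraph
(icefall/context_graph.py) and a trained BPE model are available; both paths
are placeholders:

import sentencepiece as spm

from icefall.context_graph import ContextGraph

bpe_model = spm.SentencePieceProcessor()
bpe_model.load("data/lang_bpe_500/bpe.model")  # placeholder path

# one word/phrase per line, as expected by --context-file
contexts = []
with open("context_biasing.txt") as f:  # placeholder path
    for line in f:
        contexts.append(bpe_model.encode(line.strip()))

# --context-score is the per-token bonus applied while a hypothesis is
# matching one of the biasing phrases
context_graph = ContextGraph(context_score=1.5)
context_graph.build(contexts)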