diff --git a/egs/librispeech/ASR/zipformer/decode.py b/egs/librispeech/ASR/zipformer/decode.py index cbfb3728e6..52a489eb31 100755 --- a/egs/librispeech/ASR/zipformer/decode.py +++ b/egs/librispeech/ASR/zipformer/decode.py @@ -122,6 +122,7 @@ modified_beam_search_LODR, ) from lhotse import set_caching_enabled +from tokenizer import Tokenizer from train import add_model_arguments, get_model, get_params from icefall import ContextGraph, LmScorer, NgramLm @@ -377,6 +378,17 @@ def get_parser(): default=False, help="""Skip scoring, but still save the ASR output (for eval sets).""", ) + parser.add_argument( + "--blank-penalty", + type=float, + default=0.0, + help=""" + The penalty applied on blank symbol during decoding. + Note: It is a positive value that would be applied to logits like + this `logits[:, 0] -= blank_penalty` (suppose logits.shape is + [batch_size, vocab] and blank id is 0). + """, + ) add_model_arguments(parser) @@ -601,6 +613,7 @@ def decode_one_batch( # prefix = ( "greedy_search" | "fast_beam_search_nbest" | "modified_beam_search" ) prefix = f"{params.decoding_method}" + key = f"blank_penalty_{params.blank_penalty}" if params.decoding_method == "greedy_search": return {"greedy_search": hyps} elif "fast_beam_search" in params.decoding_method: diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 69059287b4..d42a5b145c 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -2434,4 +2434,4 @@ def _test_zipformer_main(causal: bool = False): torch.set_num_threads(1) torch.set_num_interop_threads(1) _test_zipformer_main(False) - _test_zipformer_main(True) + _test_zipformer_main(True) \ No newline at end of file diff --git a/egs/reazonspeech/ASR/RESULTS.md b/egs/reazonspeech/ASR/RESULTS.md index c0b4fe54a7..92610d75bb 100644 --- a/egs/reazonspeech/ASR/RESULTS.md +++ b/egs/reazonspeech/ASR/RESULTS.md @@ -47,3 +47,41 @@ The decoding command is: --blank-penalty 0 ``` +#### Streaming + +We have not completed evaluation of our models yet and will add evaluation results here once it's completed. + +The training command is: +```shell +./zipformer/train.py \ + --world-size 8 \ + --num-epochs 40 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp-large \ + --causal 1 \ + --num-encoder-layers 2,2,4,5,4,2 \ + --feedforward-dim 512,768,1536,2048,1536,768 \ + --encoder-dim 192,256,512,768,512,256 \ + --encoder-unmasked-dim 192,192,256,320,256,192 \ + --lang data/lang_char \ + --max-duration 1600 +``` + +The decoding command is: + +```shell +./zipformer/streaming_decode.py \ + --epoch 28 \ + --avg 15 \ + --causal 1 \ + --chunk-size 32 \ + --left-context-frames 256 \ + --exp-dir ./zipformer/exp-large \ + --lang data/lang_char \ + --num-encoder-layers 2,2,4,5,4,2 \ + --feedforward-dim 512,768,1536,2048,1536,768 \ + --encoder-dim 192,256,512,768,512,256 \ + --encoder-unmasked-dim 192,192,256,320,256,192 +``` + diff --git a/egs/reazonspeech/ASR/local/utils/tokenizer.py b/egs/reazonspeech/ASR/local/utils/tokenizer.py index c9be72be10..ba71cff893 100644 --- a/egs/reazonspeech/ASR/local/utils/tokenizer.py +++ b/egs/reazonspeech/ASR/local/utils/tokenizer.py @@ -12,7 +12,6 @@ class Tokenizer: @staticmethod def add_arguments(parser: argparse.ArgumentParser): group = parser.add_argument_group(title="Lang related options") - group.add_argument("--lang", type=Path, help="Path to lang directory.") group.add_argument( diff --git a/egs/reazonspeech/ASR/zipformer/streaming_decode.py b/egs/reazonspeech/ASR/zipformer/streaming_decode.py index 4c18c75634..9274f4dc4f 100755 --- a/egs/reazonspeech/ASR/zipformer/streaming_decode.py +++ b/egs/reazonspeech/ASR/zipformer/streaming_decode.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 -# Copyright 2022 Xiaomi Corporation (Authors: Wei Kang, Fangjun Kuang) -# +# Copyright 2022-2023 Xiaomi Corporation (Authors: Wei Kang, +# Fangjun Kuang, +# Zengwei Yao) # See ../../../../LICENSE for clarification regarding multiple authors # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -17,28 +18,24 @@ """ Usage: -./pruned_transducer_stateless7_streaming/streaming_decode.py \ - --epoch 28 \ - --avg 15 \ - --decode-chunk-len 32 \ - --exp-dir ./pruned_transducer_stateless7_streaming/exp \ - --decoding_method greedy_search \ - --lang data/lang_char \ - --num-decode-streams 2000 +./zipformer/streaming_decode.py--epoch 28 --avg 15 --causal 1 --chunk-size 32 --left-context-frames 256 --exp-dir ./zipformer/exp-large --lang data/lang_char --num-encoder-layers 2,2,4,5,4,2 --feedforward-dim 512,768,1536,2048,1536,768 --encoder-dim 192,256,512,768,512,256 --encoder-unmasked-dim 192,192,256,320,256,192 + """ import argparse import logging import math +import os +import pdb + +# import subprocess as sp from pathlib import Path from typing import Dict, List, Optional, Tuple import k2 import numpy as np import torch -import torch.nn as nn from asr_datamodule import ReazonSpeechAsrDataModule -from decode import save_results from decode_stream import DecodeStream from kaldifeat import Fbank, FbankOptions from lhotse import CutSet @@ -48,9 +45,9 @@ modified_beam_search, ) from tokenizer import Tokenizer +from torch import Tensor, nn from torch.nn.utils.rnn import pad_sequence -from train import add_model_arguments, get_params, get_transducer_model -from zipformer import stack_states, unstack_states +from train import add_model_arguments, get_model, get_params from icefall.checkpoint import ( average_checkpoints, @@ -58,7 +55,14 @@ find_checkpoints, load_checkpoint, ) -from icefall.utils import AttributeDict, setup_logger, str2bool +from icefall.utils import ( + AttributeDict, + make_pad_mask, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) LOG_EPS = math.log(1e-10) @@ -73,7 +77,7 @@ def get_parser(): type=int, default=28, help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 0. + Note: Epoch counts from 1. You can specify --avg to use more checkpoints for model averaging.""", ) @@ -87,12 +91,6 @@ def get_parser(): """, ) - parser.add_argument( - "--gpu", - type=int, - default=0, - ) - parser.add_argument( "--avg", type=int, @@ -116,7 +114,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="pruned_transducer_stateless2/exp", + default="zipformer/exp", help="The experiment dir", ) @@ -127,6 +125,13 @@ def get_parser(): help="Path to the BPE model", ) + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_char", + help="The lang dir containing word table and LG graph", + ) + parser.add_argument( "--decoding-method", type=str, @@ -138,14 +143,6 @@ def get_parser(): """, ) - parser.add_argument( - "--decoding-graph", - type=str, - default="", - help="""Used only when --decoding-method is - fast_beam_search""", - ) - parser.add_argument( "--num_active_paths", type=int, @@ -157,7 +154,7 @@ def get_parser(): parser.add_argument( "--beam", type=float, - default=4.0, + default=4, help="""A floating point value to calculate the cutoff score during beam search (i.e., `cutoff = max-score - beam`), which is the same as the `beam` in Kaldi. @@ -194,18 +191,235 @@ def get_parser(): help="The number of streams that can be decoded parallel.", ) - parser.add_argument( - "--res-dir", - type=Path, - default=None, - help="The path to save results.", - ) - add_model_arguments(parser) return parser +def get_init_states( + model: nn.Module, + batch_size: int = 1, + device: torch.device = torch.device("cpu"), +) -> List[torch.Tensor]: + """ + Returns a list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6] + is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). + states[-2] is the cached left padding for ConvNeXt module, + of shape (batch_size, num_channels, left_pad, num_freqs) + states[-1] is processed_lens of shape (batch,), which records the number + of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. + """ + states = model.encoder.get_init_states(batch_size, device) + + embed_states = model.encoder_embed.get_init_states(batch_size, device) + states.append(embed_states) + + processed_lens = torch.zeros(batch_size, dtype=torch.int32, device=device) + states.append(processed_lens) + + return states + + +def stack_states(state_list: List[List[torch.Tensor]]) -> List[torch.Tensor]: + """Stack list of zipformer states that correspond to separate utterances + into a single emformer state, so that it can be used as an input for + zipformer when those utterances are formed into a batch. + + Args: + state_list: + Each element in state_list corresponding to the internal state + of the zipformer model for a single utterance. For element-n, + state_list[n] is a list of cached tensors of all encoder layers. For layer-i, + state_list[n][i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, + cached_val2, cached_conv1, cached_conv2). + state_list[n][-2] is the cached left padding for ConvNeXt module, + of shape (batch_size, num_channels, left_pad, num_freqs) + state_list[n][-1] is processed_lens of shape (batch,), which records the number + of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. + + Note: + It is the inverse of :func:`unstack_states`. + """ + batch_size = len(state_list) + assert (len(state_list[0]) - 2) % 6 == 0, len(state_list[0]) + tot_num_layers = (len(state_list[0]) - 2) // 6 + + batch_states = [] + for layer in range(tot_num_layers): + layer_offset = layer * 6 + # cached_key: (left_context_len, batch_size, key_dim) + cached_key = torch.cat( + [state_list[i][layer_offset] for i in range(batch_size)], dim=1 + ) + # cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim) + cached_nonlin_attn = torch.cat( + [state_list[i][layer_offset + 1] for i in range(batch_size)], dim=1 + ) + # cached_val1: (left_context_len, batch_size, value_dim) + cached_val1 = torch.cat( + [state_list[i][layer_offset + 2] for i in range(batch_size)], dim=1 + ) + # cached_val2: (left_context_len, batch_size, value_dim) + cached_val2 = torch.cat( + [state_list[i][layer_offset + 3] for i in range(batch_size)], dim=1 + ) + # cached_conv1: (#batch, channels, left_pad) + cached_conv1 = torch.cat( + [state_list[i][layer_offset + 4] for i in range(batch_size)], dim=0 + ) + # cached_conv2: (#batch, channels, left_pad) + cached_conv2 = torch.cat( + [state_list[i][layer_offset + 5] for i in range(batch_size)], dim=0 + ) + batch_states += [ + cached_key, + cached_nonlin_attn, + cached_val1, + cached_val2, + cached_conv1, + cached_conv2, + ] + + cached_embed_left_pad = torch.cat( + [state_list[i][-2] for i in range(batch_size)], dim=0 + ) + batch_states.append(cached_embed_left_pad) + + processed_lens = torch.cat([state_list[i][-1] for i in range(batch_size)], dim=0) + batch_states.append(processed_lens) + + return batch_states + + +def unstack_states(batch_states: List[Tensor]) -> List[List[Tensor]]: + """Unstack the zipformer state corresponding to a batch of utterances + into a list of states, where the i-th entry is the state from the i-th + utterance in the batch. + + Note: + It is the inverse of :func:`stack_states`. + + Args: + batch_states: A list of cached tensors of all encoder layers. For layer-i, + states[i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, + cached_conv1, cached_conv2). + state_list[-2] is the cached left padding for ConvNeXt module, + of shape (batch_size, num_channels, left_pad, num_freqs) + states[-1] is processed_lens of shape (batch,), which records the number + of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. + + Returns: + state_list: A list of list. Each element in state_list corresponding to the internal state + of the zipformer model for a single utterance. + """ + assert (len(batch_states) - 2) % 6 == 0, len(batch_states) + tot_num_layers = (len(batch_states) - 2) // 6 + + processed_lens = batch_states[-1] + batch_size = processed_lens.shape[0] + + state_list = [[] for _ in range(batch_size)] + + for layer in range(tot_num_layers): + layer_offset = layer * 6 + # cached_key: (left_context_len, batch_size, key_dim) + cached_key_list = batch_states[layer_offset].chunk(chunks=batch_size, dim=1) + # cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim) + cached_nonlin_attn_list = batch_states[layer_offset + 1].chunk( + chunks=batch_size, dim=1 + ) + # cached_val1: (left_context_len, batch_size, value_dim) + cached_val1_list = batch_states[layer_offset + 2].chunk( + chunks=batch_size, dim=1 + ) + # cached_val2: (left_context_len, batch_size, value_dim) + cached_val2_list = batch_states[layer_offset + 3].chunk( + chunks=batch_size, dim=1 + ) + # cached_conv1: (#batch, channels, left_pad) + cached_conv1_list = batch_states[layer_offset + 4].chunk( + chunks=batch_size, dim=0 + ) + # cached_conv2: (#batch, channels, left_pad) + cached_conv2_list = batch_states[layer_offset + 5].chunk( + chunks=batch_size, dim=0 + ) + for i in range(batch_size): + state_list[i] += [ + cached_key_list[i], + cached_nonlin_attn_list[i], + cached_val1_list[i], + cached_val2_list[i], + cached_conv1_list[i], + cached_conv2_list[i], + ] + + cached_embed_left_pad_list = batch_states[-2].chunk(chunks=batch_size, dim=0) + for i in range(batch_size): + state_list[i].append(cached_embed_left_pad_list[i]) + + processed_lens_list = batch_states[-1].chunk(chunks=batch_size, dim=0) + for i in range(batch_size): + state_list[i].append(processed_lens_list[i]) + + return state_list + + +def streaming_forward( + features: Tensor, + feature_lens: Tensor, + model: nn.Module, + states: List[Tensor], + chunk_size: int, + left_context_len: int, +) -> Tuple[Tensor, Tensor, List[Tensor]]: + """ + Returns encoder outputs, output lengths, and updated states. + """ + cached_embed_left_pad = states[-2] + (x, x_lens, new_cached_embed_left_pad,) = model.encoder_embed.streaming_forward( + x=features, + x_lens=feature_lens, + cached_left_pad=cached_embed_left_pad, + ) + assert x.size(1) == chunk_size, (x.size(1), chunk_size) + + src_key_padding_mask = make_pad_mask(x_lens) + + # processed_mask is used to mask out initial states + processed_mask = torch.arange(left_context_len, device=x.device).expand( + x.size(0), left_context_len + ) + processed_lens = states[-1] # (batch,) + # (batch, left_context_size) + processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1) + # Update processed lengths + new_processed_lens = processed_lens + x_lens + + # (batch, left_context_size + chunk_size) + src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1) + + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + encoder_states = states[:-2] + ( + encoder_out, + encoder_out_lens, + new_encoder_states, + ) = model.encoder.streaming_forward( + x=x, + x_lens=x_lens, + states=encoder_states, + src_key_padding_mask=src_key_padding_mask, + ) + encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + + new_states = new_encoder_states + [ + new_cached_embed_left_pad, + new_processed_lens, + ] + return encoder_out, encoder_out_lens, new_states + + def decode_one_chunk( params: AttributeDict, model: nn.Module, @@ -224,27 +438,32 @@ def decode_one_chunk( Returns: Return a List containing which DecodeStreams are finished. """ - device = model.device + # pdb.set_trace() + # print(model) + # print(model.device) + # device = model.device + chunk_size = int(params.chunk_size) + left_context_len = int(params.left_context_frames) features = [] feature_lens = [] states = [] - processed_lens = [] + processed_lens = [] # Used in fast-beam-search for stream in decode_streams: - feat, feat_len = stream.get_feature_frames(params.decode_chunk_len) + feat, feat_len = stream.get_feature_frames(chunk_size * 2) features.append(feat) feature_lens.append(feat_len) states.append(stream.states) processed_lens.append(stream.done_frames) - feature_lens = torch.tensor(feature_lens, device=device) + feature_lens = torch.tensor(feature_lens, device=model.device) features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS) - # We subsample features with ((x_len - 7) // 2 + 1) // 2 and the max downsampling - # factor in encoders is 8. - # After feature embedding (x_len - 7) // 2, we have (23 - 7) // 2 = 8. - tail_length = 23 + # Make sure the length after encoder_embed is at least 1. + # The encoder_embed subsample features (T - 7) // 2 + # The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling + tail_length = chunk_size * 2 + 7 + 2 * 3 if features.size(1) < tail_length: pad_length = tail_length - features.size(1) feature_lens += pad_length @@ -256,12 +475,14 @@ def decode_one_chunk( ) states = stack_states(states) - processed_lens = torch.tensor(processed_lens, device=device) - encoder_out, encoder_out_lens, new_states = model.encoder.streaming_forward( - x=features, - x_lens=feature_lens, + encoder_out, encoder_out_lens, new_states = streaming_forward( + features=features, + feature_lens=feature_lens, + model=model, states=states, + chunk_size=chunk_size, + left_context_len=left_context_len, ) encoder_out = model.joiner.encoder_proj(encoder_out) @@ -269,6 +490,7 @@ def decode_one_chunk( if params.decoding_method == "greedy_search": greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams) elif params.decoding_method == "fast_beam_search": + processed_lens = torch.tensor(processed_lens, device=model.device) processed_lens = processed_lens + encoder_out_lens fast_beam_search_one_best( model=model, @@ -295,8 +517,9 @@ def decode_one_chunk( for i in range(len(decode_streams)): decode_streams[i].states = states[i] decode_streams[i].done_frames += encoder_out_lens[i] - if decode_streams[i].done: - finished_streams.append(i) + # if decode_streams[i].done: + # finished_streams.append(i) + finished_streams.append(i) return finished_streams @@ -338,14 +561,14 @@ def decode_dataset( opts.frame_opts.samp_freq = 16000 opts.mel_opts.num_bins = 80 - log_interval = 50 + log_interval = 100 decode_results = [] # Contain decode streams currently running. decode_streams = [] for num, cut in enumerate(cuts): # each utterance has a DecodeStream. - initial_states = model.encoder.get_init_state(device=device) + initial_states = get_init_states(model=model, batch_size=1, device=device) decode_stream = DecodeStream( params=params, cut_id=cut.id, @@ -361,15 +584,19 @@ def decode_dataset( assert audio.dtype == np.float32, audio.dtype # The trained model is using normalized samples - assert audio.max() <= 1, "Should be normalized to [-1, 1])" + # - this is to avoid sending [-32k,+32k] signal in... + # - some lhotse AudioTransform classes can make the signal + # be out of range [-1, 1], hence the tolerance 10 + assert ( + np.abs(audio).max() <= 10 + ), "Should be normalized to [-1, 1], 10 for tolerance..." samples = torch.from_numpy(audio).squeeze(0) fbank = Fbank(opts) feature = fbank(samples.to(device)) - decode_stream.set_features(feature, tail_pad_len=params.decode_chunk_len) - decode_stream.ground_truth = cut.supervisions[0].custom[params.transcript_mode] - + decode_stream.set_features(feature, tail_pad_len=30) + decode_stream.ground_truth = cut.supervisions[0].text decode_streams.append(decode_stream) while len(decode_streams) >= params.num_decode_streams: @@ -380,8 +607,8 @@ def decode_dataset( decode_results.append( ( decode_streams[i].id, - sp.text2word(decode_streams[i].ground_truth), - sp.text2word(sp.decode(decode_streams[i].decoding_result())), + decode_streams[i].ground_truth.split(), + sp.decode(decode_streams[i].decoding_result()).split(), ) ) del decode_streams[i] @@ -391,18 +618,37 @@ def decode_dataset( # decode final chunks of last sequences while len(decode_streams): + # print("INSIDE LEN DECODE STREAMS") + # pdb.set_trace() + # print(model.device) + # test_device = model.device + # print("done") finished_streams = decode_one_chunk( params=params, model=model, decode_streams=decode_streams ) + # print('INSIDE FOR LOOP ') + # print(finished_streams) + + if not finished_streams: + print("No finished streams, breaking the loop") + break + for i in sorted(finished_streams, reverse=True): - decode_results.append( - ( - decode_streams[i].id, - sp.text2word(decode_streams[i].ground_truth), - sp.text2word(sp.decode(decode_streams[i].decoding_result())), + try: + decode_results.append( + ( + decode_streams[i].id, + decode_streams[i].ground_truth.split(), + sp.decode(decode_streams[i].decoding_result()).split(), + ) ) - ) - del decode_streams[i] + del decode_streams[i] + except IndexError as e: + print(f"IndexError: {e}") + print(f"decode_streams length: {len(decode_streams)}") + print(f"finished_streams: {finished_streams}") + print(f"i: {i}") + continue if params.decoding_method == "greedy_search": key = "greedy_search" @@ -416,9 +662,54 @@ def decode_dataset( key = f"num_active_paths_{params.num_active_paths}" else: raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + torch.cuda.synchronize() return {key: decode_results} +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = ( + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + @torch.no_grad() def main(): parser = get_parser() @@ -430,16 +721,20 @@ def main(): params = get_params() params.update(vars(args)) - if not params.res_dir: - params.res_dir = params.exp_dir / "streaming" / params.decoding_method + params.res_dir = params.exp_dir / "streaming" / params.decoding_method if params.iter > 0: params.suffix = f"iter-{params.iter}-avg-{params.avg}" else: params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - # for streaming - params.suffix += f"-streaming-chunk-size-{params.decode_chunk_len}" + assert params.causal, params.causal + assert "," not in params.chunk_size, "chunk_size should be one value in decoding." + assert ( + "," not in params.left_context_frames + ), "left_context_frames should be one value in decoding." + params.suffix += f"-chunk-{params.chunk_size}" + params.suffix += f"-left-context-{params.left_context_frames}" # for fast_beam_search if params.decoding_method == "fast_beam_search": @@ -455,13 +750,13 @@ def main(): device = torch.device("cpu") if torch.cuda.is_available(): - device = torch.device("cuda", params.gpu) + device = torch.device("cuda", 0) logging.info(f"Device: {device}") sp = Tokenizer.load(params.lang, params.lang_type) - # and is defined in local/prepare_lang_char.py + # and is defined in local/train_bpe_model.py params.blank_id = sp.piece_to_id("") params.unk_id = sp.piece_to_id("") params.vocab_size = sp.get_piece_size() @@ -469,7 +764,7 @@ def main(): logging.info(params) logging.info("About to create model") - model = get_transducer_model(params) + model = get_model(params) if not params.use_averaged_model: if params.iter > 0: @@ -553,42 +848,51 @@ def main(): model.device = device decoding_graph = None - if params.decoding_graph: - decoding_graph = k2.Fsa.from_dict( - torch.load(params.decoding_graph, map_location=device) - ) - elif params.decoding_method == "fast_beam_search": + if params.decoding_method == "fast_beam_search": decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") + # we need cut ids to display recognition results. args.return_cuts = True reazonspeech_corpus = ReazonSpeechAsrDataModule(args) - for subdir in ["valid"]: + valid_cuts = reazonspeech_corpus.valid_cuts() + test_cuts = reazonspeech_corpus.test_cuts() + + test_sets = ["valid", "test"] + test_cuts = [valid_cuts, test_cuts] + + for test_set, test_cut in zip(test_sets, test_cuts): results_dict = decode_dataset( - cuts=getattr(reazonspeech_corpus, f"{subdir}_cuts")(), + cuts=test_cut, params=params, model=model, sp=sp, decoding_graph=decoding_graph, ) - tot_err = save_results( - params=params, test_set_name=subdir, results_dict=results_dict + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, ) - with ( - params.res_dir - / ( - f"{subdir}-{params.decode_chunk_len}" - f"_{params.avg}_{params.epoch}.cer" - ) - ).open("w") as fout: - if len(tot_err) == 1: - fout.write(f"{tot_err[0][1]}") - else: - fout.write("\n".join(f"{k}\t{v}") for k, v in tot_err) + # valid_cuts = reazonspeech_corpus.valid_cuts() + + # for valid_cut in valid_cuts: + # results_dict = decode_dataset( + # cuts=valid_cut, + # params=params, + # model=model, + # sp=sp, + # decoding_graph=decoding_graph, + # ) + # save_results( + # params=params, + # test_set_name="valid", + # results_dict=results_dict, + # ) logging.info("Done!") diff --git a/icefall/utils.py b/icefall/utils.py index 1dbb954ded..9a25784cb1 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -631,7 +631,8 @@ def write_error_stats( results[i] = (cut_id, ref, hyp) for cut_id, ref, hyp in results: - ali = kaldialign.align(ref, hyp, ERR, sclite_mode=sclite_mode) + # ali = kaldialign.align(ref, hyp, ERR, sclite_mode=sclite_mode) + ali = kaldialign.align(ref, hyp, ERR) for ref_word, hyp_word in ali: if ref_word == ERR: ins[hyp_word] += 1