From 7dc88cd74ca1ae98481788b0efb5ec3c7b36cbe5 Mon Sep 17 00:00:00 2001 From: fboyer Date: Thu, 8 Aug 2019 16:24:00 +0200 Subject: [PATCH] RNN-T and RNN-T-Att for ESPNET v0.5.0 --- .dockerignore | 1 + .gitignore | 1 + .../asr1/conf/tuning/decode_rnnt.yaml | 4 + egs/voxforge/asr1/conf/tuning/train_rnnt.yaml | 46 + .../asr1/conf/tuning/train_rnnt_att.yaml | 43 + egs/voxforge/asr1/run.sh | 7 +- espnet/asr/asr_rnnt_utils.py | 152 ++++ espnet/asr/pytorch_backend/asr_rnnt.py | 793 ++++++++++++++++++ espnet/bin/asr_rnnt_recog.py | 130 +++ espnet/bin/asr_rnnt_train.py | 348 ++++++++ espnet/nets/pytorch_backend/e2e_asr_rnnt.py | 396 +++++++++ .../nets/pytorch_backend/rnn/decoders_rnnt.py | 635 ++++++++++++++ tools/Makefile | 17 +- tools/check_install.py | 15 +- 14 files changed, 2576 insertions(+), 12 deletions(-) create mode 100644 egs/voxforge/asr1/conf/tuning/decode_rnnt.yaml create mode 100644 egs/voxforge/asr1/conf/tuning/train_rnnt.yaml create mode 100644 egs/voxforge/asr1/conf/tuning/train_rnnt_att.yaml create mode 100644 espnet/asr/asr_rnnt_utils.py create mode 100644 espnet/asr/pytorch_backend/asr_rnnt.py create mode 100755 espnet/bin/asr_rnnt_recog.py create mode 100755 espnet/bin/asr_rnnt_train.py create mode 100644 espnet/nets/pytorch_backend/e2e_asr_rnnt.py create mode 100644 espnet/nets/pytorch_backend/rnn/decoders_rnnt.py diff --git a/.dockerignore b/.dockerignore index b0c0f0f0fdd..0f09f541bb8 100644 --- a/.dockerignore +++ b/.dockerignore @@ -10,6 +10,7 @@ tools/miniconda.sh tools/nkf/ tools/venv/ tools/warp-ctc/ +tools/warp-transducer/ tools/chainer_ctc/ tools/subword-nmt/ diff --git a/.gitignore b/.gitignore index f25eeba5c7f..414b9e9d966 100644 --- a/.gitignore +++ b/.gitignore @@ -46,5 +46,6 @@ tools/venv/ tools/sentencepiece/ tools/swig/ tools/warp-ctc/ +tools/warp-transducer/ tools/*.done tools/PESQ* diff --git a/egs/voxforge/asr1/conf/tuning/decode_rnnt.yaml b/egs/voxforge/asr1/conf/tuning/decode_rnnt.yaml new file mode 100644 index 00000000000..80c3b71a364 --- /dev/null +++ b/egs/voxforge/asr1/conf/tuning/decode_rnnt.yaml @@ -0,0 +1,4 @@ +# decoding parameters +beam-size: 20 +search-type: beam +score-norm: True \ No newline at end of file diff --git a/egs/voxforge/asr1/conf/tuning/train_rnnt.yaml b/egs/voxforge/asr1/conf/tuning/train_rnnt.yaml new file mode 100644 index 00000000000..f3e89256268 --- /dev/null +++ b/egs/voxforge/asr1/conf/tuning/train_rnnt.yaml @@ -0,0 +1,46 @@ +# minibatch related +batch-size: 20 +maxlen-in: 800 +maxlen-out: 150 + +# optimization related +sortagrad: 0 +opt: adadelta +epochs: 20 +patience: 3 + +# network architecture +## encoder related +etype: vggblstmp +elayers: 4 +eunits: 320 +eprojs: 320 +## decoder related +dtype: lstm +dlayers: 1 +dunits: 320 +dec-embed-dim: 320 +## joint network related +joint-dim: 320 + +# rnn-t related (0:rnnt, 1:rnnt-att) +rnnt-mode: 0 + +# finetuning related +## Note : Current implementation only allow to do pre-initialization with models +## matching the configuration specified above. The architecture you specify +## should match the modules architecture you want to do finetuning with. +## For example, if you want to pre-initialize the decoder embedding layer +## from a specified model, "dec-embed-dim" param should be set accordingly. + +# following model correspond to conf/tuning/train_mtlalpha1.0.yaml +#enc-init: "exp/tr_it_pytorch_train_mtlalpha1.0/results/model.loss.best" +#enc-init-mods: "enc.enc." 
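The commented `enc-init`/`enc-init-mods` pair above selects parameters from a pre-trained snapshot by name prefix. A minimal sketch of what that selection amounts to — the snapshot path is just the commented example above and the prefix is the default `enc-init-mods` value, so both are illustrative:

```python
# Sketch only: shows how an "enc-init-mods" prefix picks entries out of a
# pre-trained ESPnet state dict; path and prefix are the commented examples.
from collections import OrderedDict

import torch

snapshot = torch.load("exp/tr_it_pytorch_train_mtlalpha1.0/results/model.loss.best",
                      map_location=lambda storage, loc: storage)

prefix = "enc.enc."  # value of enc-init-mods
selected = OrderedDict((k, v) for k, v in snapshot.items() if k.startswith(prefix))
print("%d encoder tensors would be pre-initialized" % len(selected))
```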
+ +# following model is a CE-trained RNNLM similar to the one in: +# egs/librispeech/asr1/conf/tuning/lm.yaml +#dec-init: "exp/train_rnnlm_pytorch/rnnlm.model.best" +#dec-init-mods: "predictor.rnn.,predictor.embed." + +# freeze modules +#freeze-modules: "predictor.rnn." \ No newline at end of file diff --git a/egs/voxforge/asr1/conf/tuning/train_rnnt_att.yaml b/egs/voxforge/asr1/conf/tuning/train_rnnt_att.yaml new file mode 100644 index 00000000000..59976f8e7eb --- /dev/null +++ b/egs/voxforge/asr1/conf/tuning/train_rnnt_att.yaml @@ -0,0 +1,43 @@ +# minibatch related +batch-size: 20 +maxlen-in: 800 +maxlen-out: 150 + +# optimization related +sortagrad: 0 +opt: adadelta +epochs: 20 +patience: 3 + +# network architecture +## encoder related +etype: vggblstmp +elayers: 4 +eunits: 320 +eprojs: 320 +## decoder related +dlayers: 1 +dunits: 300 +## attention related +atype: location +adim: 320 +aconv-chans: 10 +aconv-filts: 100 +## joint network related +joint-dim: 320 + +# rnn-t related (0:rnnt, 1:rnnt-att) +rnnt-mode: 1 + +# finetuning related +## Note : Current implementation only allow to do pre-initialization with models +## matching the configuration specified above. The architecture you specify +## should match the modules architecture you want to do finetuning with. +## For example, if you want to pre-initialize the decoder embedding layer +## from a specified model, "dec-embed-dim" param should be set accordingly. + +# following model correspond to conf/tuning/train_mtlalpha1.0.yaml +#enc-init: "exp/tr_it_pytorch_train_mtlalpha1.0/results/model.loss.best" + +# following model correspond to conf/tuning/train_mtlalpha0.0.yaml +#dec-init: "exp/tr_it_pytorch_train_mtlalpha0.0/results/model.loss.best" \ No newline at end of file diff --git a/egs/voxforge/asr1/run.sh b/egs/voxforge/asr1/run.sh index 9ccfe5ad48f..e5c27914b38 100755 --- a/egs/voxforge/asr1/run.sh +++ b/egs/voxforge/asr1/run.sh @@ -3,8 +3,8 @@ # Copyright 2017 Johns Hopkins University (Shinji Watanabe) # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) -. ./path.sh -. ./cmd.sh +. ./path.sh || exit 1; +. ./cmd.sh || exit 1; # general configuration backend=pytorch @@ -36,9 +36,6 @@ tag="" # tag for managing experiments. . utils/parse_options.sh || exit 1; -. ./path.sh -. ./cmd.sh - # Set bash to 'debug' mode, it will exit on : # -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands', set -e diff --git a/espnet/asr/asr_rnnt_utils.py b/espnet/asr/asr_rnnt_utils.py new file mode 100644 index 00000000000..55bd67d6db3 --- /dev/null +++ b/espnet/asr/asr_rnnt_utils.py @@ -0,0 +1,152 @@ +import os +import logging +import torch + +from collections import OrderedDict + +# Note : the following methods are just toy examples for finetuning RNN-T and RNN-T att. +# It's quite inelegant as if but it handle most of the cases. + +def load_pretrained_modules(model, rnnt_mode, enc_pretrained, dec_pretrained, + enc_mods, dec_mods): + """Method to update model modules weights from up to two ESPNET pre-trained models. + Specified models can be either trained with CTC, attention or joint CTC-attention, + or a language model trained with CE to initialize the decoder part in vanilla RNN-T. 
+ + Args: + model (torch.nn.Module): initial torch model + rnnt_mode (int): RNN-transducer mode + enc_pretrained (str): ESPNET pre-trained model path file to initialize encoder part + dec_pretrained (str): ESPNET pre-trained model path file to initialize decoder part + """ + + def filter_modules(model, modules): + new_mods = [] + incorrect_mods = [] + + mods_model = list(model.keys()) + for mod in modules: + if any(key.startswith(mod) for key in mods_model): + new_mods += [mod] + else: + incorrect_mods += [mod] + + if incorrect_mods: + logging.info("Some specified module(s) for finetuning don\'t " + "match (or partially match) available modules." + " Disabling the following modules: %s", incorrect_mods) + logging.info('For information, the existing modules in model are:') + logging.info('%s', mods_model) + + return new_mods + + def validate_modules(model, prt, modules): + modules_model = [] + modules_prt = [] + + for key_p, value_p in prt.items(): + if any(key_p.startswith(m) for m in modules): + modules_prt += [(key_p, value_p.shape)] + + for key_m, value_m in model.items(): + if any(key_m.startswith(m) for m in modules): + modules_model += [(key_m, value_m.shape)] + + len_match = (len(modules_model) == len(modules_prt)) + module_match = (sorted([x for x in modules_model]) == \ + sorted([x for x in modules_prt])) + + return len_match and module_match + + def get_am_state_dict(model, modules): + new_state_dict = OrderedDict() + + for key, value in model.items(): + if any(key.startswith(m) for m in modules): + if not 'output' in key: + new_state_dict[key] = value + + return new_state_dict + + def get_lm_state_dict(model, modules): + new_state_dict = OrderedDict() + new_modules = [] + + for key, value in list(model.items()): + if key == "predictor.embed.weight" \ + and "predictor.embed." in modules: + new_key = "dec.embed.weight" + new_state_dict[new_key] = value + new_modules += [new_key] + elif "predictor.rnn." in key \ + and "predictor.rnn." in modules: + new_key = "dec.decoder." + key.split("predictor.rnn.",1)[1] + new_state_dict[new_key] = value + new_modules += [new_key] + + return new_state_dict, new_modules + + model_state_dict = model.state_dict() + + if rnnt_mode == 0 and dec_pretrained is not None: + lm_init = True + else: + lm_init = False + + for prt_model_path, prt_mods, prt_type in [(enc_pretrained, enc_mods, False), + (dec_pretrained, dec_mods, lm_init)]: + if prt_model_path is not None: + if os.path.isfile(prt_model_path): + prt_model = torch.load(prt_model_path, + map_location=lambda storage, loc: storage) + + prt_mods = filter_modules(prt_model, prt_mods) + if prt_type: + prt_state_dict, prt_mods = get_lm_state_dict(prt_model, prt_mods) + else: + prt_state_dict = get_am_state_dict(prt_model, prt_mods) + + if prt_state_dict: + if validate_modules(model_state_dict, prt_state_dict, prt_mods): + model_state_dict.update(prt_state_dict) + else: + logging.info("The model you specified for pre-initialization " + "doesn\'t match your run config. It will be ignored") + logging.info('Model path: %s' % prt_model_path) + else: + logging.info('The model you specified for pre-initialization was not found.') + logging.info('Model path: %s' % prt_model_path) + + model.load_state_dict(model_state_dict) + + del model_state_dict + + return model + +def freeze_modules(model, modules): + """Method to freeze specified list of modules. 
+ + Args: + model (torch.nn.Module): torch model + modules (str): list of module names to freeze during training + + Returns: + (boolean): if True, filter the specified modules in the optimizer + """ + + mods_model = [name for name, _ in model.named_parameters()] + + if not any(i in j for j in mods_model for i in modules): + logging.info("Some module(s) you specified to freeze don\'t " + "match (or partially match) available modules.") + logging.info("Disabling the option.") + logging.info('For information, the existing modules in model are:') + logging.info('%s', mods_model) + + return False + + for name, param in model.named_parameters(): + if any(name.startswith(m) for m in modules): + param.requires_grad = False + + return True diff --git a/espnet/asr/pytorch_backend/asr_rnnt.py b/espnet/asr/pytorch_backend/asr_rnnt.py new file mode 100644 index 00000000000..8fca5f5441e --- /dev/null +++ b/espnet/asr/pytorch_backend/asr_rnnt.py @@ -0,0 +1,793 @@ +#!/usr/bin/env python + +import copy +import json +import logging +import math +import os +import sys +import torch +import numpy as np + +from chainer.datasets import TransformDataset +from chainer import reporter as reporter_module +from chainer import training +from chainer.training import extensions + +from tensorboardX import SummaryWriter + +from espnet.asr.asr_utils import adadelta_eps_decay +from espnet.asr.asr_utils import add_results_to_json +from espnet.asr.asr_utils import CompareValueTrigger +from espnet.asr.asr_utils import get_model_conf +from espnet.asr.asr_utils import plot_spectrogram +from espnet.asr.asr_utils import restore_snapshot +from espnet.asr.asr_utils import snapshot_object +from espnet.asr.asr_utils import torch_load +from espnet.asr.asr_utils import torch_resume +from espnet.asr.asr_utils import torch_snapshot + +from espnet.asr.asr_rnnt_utils import load_pretrained_modules +from espnet.asr.asr_rnnt_utils import freeze_modules + +import espnet.lm.pytorch_backend.extlm as extlm_pytorch +import espnet.lm.pytorch_backend.lm as lm_pytorch + +from espnet.nets.asr_interface import ASRInterface +from espnet.nets.pytorch_backend.nets_utils import pad_list + +from espnet.transform.spectrogram import IStft +from espnet.transform.transformation import Transformation + +from espnet.utils.cli_writers import file_writer_helper +from espnet.utils.deterministic_utils import set_deterministic_pytorch +from espnet.utils.dynamic_import import dynamic_import +from espnet.utils.io_utils import LoadInputsAndTargets +from espnet.utils.training.batchfy import make_batchset +from espnet.utils.training.iterators import ShufflingEnabler +from espnet.utils.training.iterators import ToggleableShufflingMultiprocessIterator +from espnet.utils.training.iterators import ToggleableShufflingSerialIterator +from espnet.utils.training.tensorboard_logger import TensorboardLogger +from espnet.utils.training.train_utils import check_early_stop +from espnet.utils.training.train_utils import set_early_stop + +import matplotlib +matplotlib.use('Agg') + +if sys.version_info[0] == 2: + from itertools import izip_longest as zip_longest +else: + from itertools import zip_longest as zip_longest + +REPORT_INTERVAL = 100 + + +class CustomEvaluator(extensions.Evaluator): + """Custom Evaluator for Pytorch. + + Args: + model (torch.nn.Module): The model to evaluate. + iterator (chainer.dataset.Iterator) : The train iterator. + + target (link | dict[str, link]) :Link object or a dictionary of + links to evaluate. 
If this is just a link object, the link is + registered by the name ``'main'``. + converter (espnet.asr.pytorch_backend.asr.CustomConverter): Converter + function to build input arrays. Each batch extracted by the main + iterator and the ``device`` option are passed to this function. + :func:`chainer.dataset.concat_examples` is used by default. + + device (torch.device): The device used. + """ + + def __init__(self, model, iterator, target, converter, device): + super(CustomEvaluator, self).__init__(iterator, target) + self.model = model + self.converter = converter + self.device = device + + # The core part of the update routine can be customized by overriding + def evaluate(self): + iterator = self._iterators['main'] + + if self.eval_hook: + self.eval_hook(self) + + if hasattr(iterator, 'reset'): + iterator.reset() + it = iterator + else: + it = copy.copy(iterator) + + summary = reporter_module.DictSummary() + + self.model.eval() + with torch.no_grad(): + for batch in it: + observation = {} + with reporter_module.report_scope(observation): + # read scp files + # x: original json with loaded features + # will be converted to chainer variable later + x = self.converter(batch, self.device) + self.model(*x) + summary.add(observation) + self.model.train() + + return summary.compute_mean() + + +class CustomUpdater(training.StandardUpdater): + """Custom Updater for Pytorch. + + Args: + model (torch.nn.Module): The model to update. + grad_clip_threshold (int): The gradient clipping value to use. + train_iter (chainer.dataset.Iterator): The training iterator. + optimizer (torch.optim.optimizer): The training optimizer. + + converter (espnet.asr.pytorch_backend.asr.CustomConverter): Converter + function to build input arrays. Each batch extracted by the main + iterator and the ``device`` option are passed to this function. + :func:`chainer.dataset.concat_examples` is used by default. + + device (torch.device): The device to use. + ngpu (int): The number of gpus to use. + """ + + def __init__(self, model, grad_clip_threshold, train_iter, + optimizer, converter, device, ngpu, grad_noise=False, accum_grad=1): + super(CustomUpdater, self).__init__(train_iter, optimizer) + self.model = model + self.grad_clip_threshold = grad_clip_threshold + self.converter = converter + self.device = device + self.ngpu = ngpu + self.accum_grad = accum_grad + self.forward_count = 0 + self.grad_noise = grad_noise + self.iteration = 0 + + # The core part of the update routine can be customized by overriding. + def update_core(self): + # When we pass one iterator and optimizer to StandardUpdater.__init__, + # they are automatically named 'main'. 
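        # Overall flow of one update: fetch a converted minibatch, compute the
        # transducer loss scaled by 1/accum_grad, backpropagate (optionally
        # injecting gradient noise), and only clip the global gradient norm and
        # step the optimizer once every `accum_grad` forward/backward passes;
        # a NaN gradient norm skips the parameter update.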
+ train_iter = self.get_iterator('main') + optimizer = self.get_optimizer('main') + + # Get the next batch ( a list of json files) + batch = train_iter.next() + self.iteration += 1 + x = self.converter(batch, self.device) + + # Compute the loss at this time step and accumulate it + loss = self.model(*x).mean() / self.accum_grad + loss.backward() # Backprop + #gradient noise injection + if self.grad_noise: + from espnet.asr.asr_utils import add_gradient_noise + add_gradient_noise(self.model, self.iteration, + duration=100, eta=1.0, scale_factor=0.55) + loss.detach() # Truncate the graph + + # update parameters + self.forward_count += 1 + if self.forward_count != self.accum_grad: + return + self.forward_count = 0 + + # compute the gradient norm to check if it is normal or not + grad_norm = torch.nn.utils.clip_grad_norm_( + self.model.parameters(), self.grad_clip_threshold) + logging.info('grad norm={}'.format(grad_norm)) + if math.isnan(grad_norm): + logging.warning('grad norm is nan. Do not update model.') + else: + optimizer.step() + optimizer.zero_grad() + + def update(self): + self.update_core() + # #iterations with accum_grad > 1 + # Ref.: https://github.com/espnet/espnet/issues/777 + if self.forward_count == 0: + self.iteration += 1 + +class CustomConverter(object): + """Custom batch converter for Pytorch + + Args: + subsampling_factor (int): The subsampling factor + """ + + def __init__(self, subsampling_factor=1): + self.subsampling_factor = subsampling_factor + self.ignore_id = -1 + + def __call__(self, batch, device): + """Transforms a batch and send it to a device. + + Args: + batch (list): The batch to transform. + device (torch.device): The device to send to. + + Returns: + tuple(torch.Tensor, torch.Tensor, torch.Tensor) + """ + + # batch should be located in list + assert len(batch) == 1 + xs, ys = batch[0] + + # perform subsampling + if self.subsampling_factor > 1: + xs = [x[::self.subsampling_factor, :] for x in xs] + + # get batch of lengths of input sequences + ilens = np.array([x.shape[0] for x in xs]) + + # perform padding and convert to tensor + # currently only support real number + if xs[0].dtype.kind == 'c': + xs_pad_real = pad_list( + [torch.from_numpy(x.real).float() for x in xs], 0).to(device) + xs_pad_imag = pad_list( + [torch.from_numpy(x.imag).float() for x in xs], 0).to(device) + # Note(kamo): + # {'real': ..., 'imag': ...} will be changed to ComplexTensor in E2E. + # Don't create ComplexTensor and give it E2E here + # because torch.nn.DataParellel can't handle it. + xs_pad = {'real': xs_pad_real, 'imag': xs_pad_imag} + else: + xs_pad = pad_list([torch.from_numpy(x).float() for x in xs], 0).to(device) + + ilens = torch.from_numpy(ilens).to(device) + # NOTE: this is for multi-task learning (e.g., speech translation) + ys_pad = pad_list([torch.from_numpy(np.array(y[0]) if isinstance(y, tuple) else y).long() + for y in ys], self.ignore_id).to(device) + + return xs_pad, ilens, ys_pad + +def load_trained_model(model_path): + """Load the trained model. 
+ + Args: + model_path(str): Path to model.***.best + """ + + # read training config + idim, odim, train_args = get_model_conf( + model_path, os.path.join(os.path.dirname(model_path), 'model.json')) + + # load trained model parameters + logging.info('reading model parameters from ' + model_path) + # To be compatible with v.0.3.0 models + if hasattr(train_args, "model_module"): + model_module = train_args.model_module + else: + model_module = "espnet.nets.pytorch_backend.e2e_asr:E2E" + model_class = dynamic_import(model_module) + model = model_class(idim, odim, train_args) + torch_load(model_path, model) + + return model, train_args + + +def train(args): + """Train with the given args + + Args: + args (Namespace): The program arguments + """ + + set_deterministic_pytorch(args) + + # check cuda availability + if not torch.cuda.is_available(): + logging.warning('cuda is not available') + + # get input and output dimension info + with open(args.valid_json, 'rb') as f: + valid_json = json.load(f)['utts'] + utts = list(valid_json.keys()) + idim = int(valid_json[utts[0]]['input'][0]['shape'][1]) + odim = int(valid_json[utts[0]]['output'][0]['shape'][1]) + logging.info('#input dims : ' + str(idim)) + logging.info('#output dims: ' + str(odim)) + + # define rnn-trans or rnnt-trans with att. + if args.rnnt_mode == 1: + logging.info('MODE: RNN-Transducer with attention') + else: + logging.info('MODE: RNN-Transducer') + + # specify model architecture + model_class = dynamic_import(args.model_module) + model = model_class(idim, odim, args) + assert isinstance(model, ASRInterface) + + freeze_mode = False + if args.resume is None and \ + (args.enc_init is not None or args.dec_init is not None): + model = load_pretrained_modules(model, args.rnnt_mode, + args.enc_init, args.dec_init, + args.enc_init_mods, args.dec_init_mods) + if args.freeze_modules: + freeze_mode = freeze_modules(model, args.freeze_modules) + + subsampling_factor = model.subsample[0] + + if args.rnnlm is not None: + rnnlm_args = get_model_conf(args.rnnlm, args.rnnlm_conf) + rnnlm = lm_pytorch.ClassifierWithState( + lm_pytorch.RNNLM( + len(args.char_list), rnnlm_args.layer, rnnlm_args.unit)) + torch.load(args.rnnlm, rnnlm) + model.rnnlm = rnnlm + + # write model config + if not os.path.exists(args.outdir): + os.makedirs(args.outdir) + model_conf = args.outdir + '/model.json' + with open(model_conf, 'wb') as f: + logging.info('writing a model config file to ' + model_conf) + f.write(json.dumps((idim, odim, vars(args)), indent=4, + sort_keys=True).encode('utf_8')) + for key in sorted(vars(args).keys()): + logging.info('ARGS: ' + key + ': ' + str(vars(args)[key])) + + reporter = model.reporter + + # check the use of multi-gpu + if args.ngpu > 1: + model = torch.nn.DataParallel(model, device_ids=list(range(args.ngpu))) + logging.info('batch size is automatically increased (%d -> %d)' % ( + args.batch_size, args.batch_size * args.ngpu)) + args.batch_size *= args.ngpu + + # set torch device + device = torch.device("cuda" if args.ngpu > 0 else "cpu") + model = model.to(device) + + # Setup an optimizer + if freeze_mode: + params = filter(lambda p: p.requires_grad, model.parameters()) + else: + params = model.parameters() + + if args.opt == 'adadelta': + optimizer = torch.optim.Adadelta( + params, rho=0.95, eps=args.eps, + weight_decay=args.weight_decay) + elif args.opt == 'adam': + optimizer = torch.optim.Adam(params, + weight_decay=args.weight_decay) + else: + raise NotImplementedError("unknown optimizer: " + args.opt) + + # FIXME: TOO DIRTY HACK 
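    # Background for the hack below: the chainer Trainer and its extensions
    # expect every optimizer to expose a `target` link and a `serialize` method,
    # so the reporter is grafted onto the torch optimizer to keep snapshotting
    # and reporting working without a real chainer optimizer.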
+ setattr(optimizer, "target", reporter) + setattr(optimizer, "serialize", lambda s: reporter.serialize(s)) + + # Setup a converter + converter = CustomConverter(subsampling_factor=subsampling_factor) + + # read json data + with open(args.train_json, 'rb') as f: + train_json = json.load(f)['utts'] + with open(args.valid_json, 'rb') as f: + valid_json = json.load(f)['utts'] + + use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0 + # make minibatch list (variable length) + train = make_batchset(train_json, args.batch_size, + args.maxlen_in, args.maxlen_out, args.minibatches, + min_batch_size=args.ngpu if args.ngpu > 1 else 1, + shortest_first=use_sortagrad, + count=args.batch_count, + batch_bins=args.batch_bins, + batch_frames_in=args.batch_frames_in, + batch_frames_out=args.batch_frames_out, + batch_frames_inout=args.batch_frames_inout) + valid = make_batchset(valid_json, args.batch_size, + args.maxlen_in, args.maxlen_out, args.minibatches, + min_batch_size=args.ngpu if args.ngpu > 1 else 1, + count=args.batch_count, + batch_bins=args.batch_bins, + batch_frames_in=args.batch_frames_in, + batch_frames_out=args.batch_frames_out, + batch_frames_inout=args.batch_frames_inout) + + load_tr = LoadInputsAndTargets( + mode='asr', load_output=True, preprocess_conf=args.preprocess_conf, + preprocess_args={'train': True} # Switch the mode of preprocessing + ) + load_cv = LoadInputsAndTargets( + mode='asr', load_output=True, preprocess_conf=args.preprocess_conf, + preprocess_args={'train': False} # Switch the mode of preprocessing + ) + + # hack to make batchsize argument as 1 + # actual bathsize is included in a list + if args.n_iter_processes > 0: + train_iter = ToggleableShufflingMultiprocessIterator( + TransformDataset(train, load_tr), + batch_size=1, n_processes=args.n_iter_processes, n_prefetch=8, maxtasksperchild=20, + shuffle=not use_sortagrad) + valid_iter = ToggleableShufflingMultiprocessIterator( + TransformDataset(valid, load_cv), + batch_size=1, repeat=False, shuffle=False, + n_processes=args.n_iter_processes, n_prefetch=8, maxtasksperchild=20) + else: + train_iter = ToggleableShufflingSerialIterator( + TransformDataset(train, load_tr), + batch_size=1, shuffle=not use_sortagrad) + valid_iter = ToggleableShufflingSerialIterator( + TransformDataset(valid, load_cv), + batch_size=1, repeat=False, shuffle=False) + + # Set up a trainer + updater = CustomUpdater( + model, args.grad_clip, train_iter, optimizer, converter, device, args.ngpu, + args.grad_noise, args.accum_grad) + trainer = training.Trainer( + updater, (args.epochs, 'epoch'), out=args.outdir) + + if use_sortagrad: + trainer.extend(ShufflingEnabler([train_iter]), + trigger=(args.sortagrad if args.sortagrad != -1 else args.epochs, 'epoch')) + + # Resume from a snapshot + if args.resume: + logging.info('resumed from %s' % args.resume) + torch_resume(args.resume, trainer) + + # Evaluate the model with the test dataset for each epoch + trainer.extend(CustomEvaluator(model, valid_iter, reporter, converter, device)) + + # Save attention weight each epoch + if args.num_save_attention > 0 and args.rnnt_mode == 1: + data = sorted(list(valid_json.items())[:args.num_save_attention], + key=lambda x: int(x[1]['input'][0]['shape'][1]), reverse=True) + if hasattr(model, "module"): + att_vis_fn = model.module.calculate_all_attentions + plot_class = model.module.attention_plot_class + else: + att_vis_fn = model.calculate_all_attentions + plot_class = model.attention_plot_class + att_reporter = plot_class( + att_vis_fn, data, args.outdir + 
"/att_ws", + converter=converter, transform=load_cv, device=device) + trainer.extend(att_reporter, trigger=(1, 'epoch')) + else: + att_reporter = None + + # Make a plot for training and validation values + trainer.extend(extensions.PlotReport(['main/loss', 'validation/main/loss'], + 'epoch', file_name='loss.png')) + + # Save best models + trainer.extend(snapshot_object(model, 'model.loss.best'), + trigger=training.triggers.MinValueTrigger('validation/main/loss')) + + # save snapshot which contains model and optimizer states + trainer.extend(torch_snapshot(), trigger=(1, 'epoch')) + + # epsilon decay in the optimizer + if args.opt == 'adadelta': + trainer.extend(restore_snapshot(model, args.outdir + '/model.loss.best', load_fn=torch_load), + trigger=CompareValueTrigger( + 'validation/main/loss', + lambda best_value, current_value: best_value < current_value)) + trainer.extend(adadelta_eps_decay(args.eps_decay), + trigger=CompareValueTrigger( + 'validation/main/loss', + lambda best_value, current_value: best_value < current_value)) + + # Write a log of evaluation statistics for each epoch + trainer.extend(extensions.LogReport(trigger=(REPORT_INTERVAL, 'iteration'))) + + report_keys = ['epoch', 'iteration', 'main/loss', 'validation/main/loss', + 'elapsed_time'] + + if args.opt == 'adadelta': + trainer.extend(extensions.observe_value( + 'eps', lambda trainer: trainer.updater.get_optimizer('main').param_groups[0]["eps"]), + trigger=(REPORT_INTERVAL, 'iteration')) + report_keys.append('eps') + if args.report_cer: + report_keys.append('validation/main/cer') + if args.report_wer: + report_keys.append('validation/main/wer') + trainer.extend(extensions.PrintReport( + report_keys), trigger=(REPORT_INTERVAL, 'iteration')) + + trainer.extend(extensions.ProgressBar(update_interval=REPORT_INTERVAL)) + set_early_stop(trainer, args) + + if args.tensorboard_dir is not None and args.tensorboard_dir != "": + writer = SummaryWriter(args.tensorboard_dir) + trainer.extend(TensorboardLogger(writer, att_reporter)) + # Run the training + trainer.run() + check_early_stop(trainer, args.epochs) + +def recog(args): + """Decode with the given args. 
+ + Args: + args (Namespace): The program arguments + """ + + set_deterministic_pytorch(args) + model, train_args = load_trained_model(args.model) + + assert isinstance(model, ASRInterface) + model.recog_args = args + + # read rnnlm + if args.rnnlm: + rnnlm_args = get_model_conf(args.rnnlm, args.rnnlm_conf) + rnnlm = lm_pytorch.ClassifierWithState( + lm_pytorch.RNNLM( + len(train_args.char_list), rnnlm_args.layer, rnnlm_args.unit)) + torch_load(args.rnnlm, rnnlm) + rnnlm.eval() + else: + rnnlm = None + + if args.word_rnnlm: + rnnlm_args = get_model_conf(args.word_rnnlm, args.word_rnnlm_conf) + word_dict = rnnlm_args.char_list_dict + char_dict = {x: i for i, x in enumerate(train_args.char_list)} + word_rnnlm = lm_pytorch.ClassifierWithState(lm_pytorch.RNNLM( + len(word_dict), rnnlm_args.layer, rnnlm_args.unit)) + + torch_load(args.word_rnnlm, word_rnnlm) + word_rnnlm.eval() + + if rnnlm is not None: + rnnlm = lm_pytorch.ClassifierWithState( + extlm_pytorch.MultiLevelLM(word_rnnlm.predictor, + rnnlm.predictor, word_dict, char_dict)) + else: + rnnlm = lm_pytorch.ClassifierWithState( + extlm_pytorch.LookAheadWordLM(word_rnnlm.predictor, + word_dict, char_dict)) + + # gpu + if args.ngpu == 1: + gpu_id = range(args.ngpu) + logging.info('gpu id: ' + str(gpu_id)) + model.cuda() + if rnnlm: + rnnlm.cuda() + + # read json data + with open(args.recog_json, 'rb') as f: + js = json.load(f)['utts'] + new_js = {} + + load_inputs_and_targets = LoadInputsAndTargets( + mode='asr', load_output=False, sort_in_input_length=False, + preprocess_conf=train_args.preprocess_conf + if args.preprocess_conf is None else args.preprocess_conf, + preprocess_args={'train': False}) + + if args.batchsize <= 1: + with torch.no_grad(): + for idx, name in enumerate(js.keys(), 1): + logging.info('(%d/%d) decoding ' + name, idx, len(js.keys())) + batch = [(name, js[name])] + feat = load_inputs_and_targets(batch)[0][0] + nbest_hyps = model.recognize(feat, args, train_args.char_list, rnnlm) + new_js[name] = add_results_to_json(js[name], nbest_hyps, train_args.char_list) + else: + raise NotImplementedError + + def grouper(n, iterable, fillvalue=None): + kargs = [iter(iterable)] * n + return zip_longest(*kargs, fillvalue=fillvalue) + + # sort data + keys = list(js.keys()) + feat_lens = [js[key]['input'][0]['shape'][0] for key in keys] + sorted_index = sorted(range(len(feat_lens)), key=lambda i: -feat_lens[i]) + keys = [keys[i] for i in sorted_index] + + with torch.no_grad(): + for names in grouper(args.batchsize, keys, None): + names = [name for name in names if name] + batch = [(name, js[name]) for name in names] + feats = load_inputs_and_targets(batch)[0] + nbest_hyps = model.recognize_batch(feats, args, train_args.char_list, rnnlm) + + for i, nbest_hyp in enumerate(nbest_hyps): + name = names[i] + new_js[name] = add_results_to_json(js[name], nbest_hyp, train_args.char_list) + + with open(args.result_label, 'wb') as f: + f.write(json.dumps({'utts': new_js}, indent=4, + ensure_ascii=False, sort_keys=True).encode('utf_8')) + +def enchance(args): + """Dumping enhanced speech and mask + + Args: + args (Namespace): The program arguments + """ + + set_deterministic_pytorch(args) + # read training config + idim, odim, train_args = get_model_conf(args.model, args.model_conf) + + # load trained model parameters + logging.info('reading model parameters from ' + args.model) + model_class = dynamic_import(train_args.model_module) + model = model_class(idim, odim, train_args) + assert isinstance(model, ASRInterface) + torch_load(args.model, 
model) + model.recog_args = args + + # gpu + if args.ngpu == 1: + gpu_id = list(range(args.ngpu)) + logging.info('gpu id: ' + str(gpu_id)) + model.cuda() + + # read json data + with open(args.recog_json, 'rb') as f: + js = json.load(f)['utts'] + + load_inputs_and_targets = LoadInputsAndTargets( + mode='asr', load_output=False, sort_in_input_length=False, + preprocess_conf=None # Apply pre_process in outer func + ) + if args.batchsize == 0: + args.batchsize = 1 + + # Creates writers for outputs from the network + if args.enh_wspecifier is not None: + enh_writer = FileWriterWrapper(args.enh_wspecifier, + filetype=args.enh_filetype) + else: + enh_writer = None + + # Creates a Transformation instance + preprocess_conf = ( + train_args.preprocess_conf if args.preprocess_conf is None + else args.preprocess_conf) + if preprocess_conf is not None: + logging.info('Use preprocessing'.format(preprocess_conf)) + transform = Transformation(preprocess_conf) + else: + transform = None + + # Creates a IStft instance + istft = None + frame_shift = args.istft_n_shift # Used for plot the spectrogram + if args.apply_istft: + if preprocess_conf is not None: + # Read the conffile and find stft setting + with open(preprocess_conf) as f: + # Json format: e.g. + # {"process": [{"type": "stft", + # "win_length": 400, + # "n_fft": 512, "n_shift": 160, + # "window": "han"}, + # {"type": "foo", ...}, ...]} + conf = json.load(f) + assert 'process' in conf, conf + # Find stft setting + for p in conf['process']: + if p['type'] == 'stft': + istft = IStft(win_length=p['win_length'], + n_shift=p['n_shift'], + window=p.get('window', 'hann')) + logging.info('stft is found in {}. ' + 'Setting istft config from it\n{}' + .format(preprocess_conf, istft)) + frame_shift = p['n_shift'] + break + if istft is None: + # Set from command line arguments + istft = IStft(win_length=args.istft_win_length, + n_shift=args.istft_n_shift, + window=args.istft_window) + logging.info('Setting istft config from the command line args\n{}' + .format(istft)) + + keys = list(js.keys()) + feat_lens = [js[key]['input'][0]['shape'][0] for key in keys] + sorted_index = sorted(range(len(feat_lens)), key=lambda i: -feat_lens[i]) + keys = [keys[i] for i in sorted_index] + + def grouper(n, iterable, fillvalue=None): + kargs = [iter(iterable)] * n + return zip_longest(*kargs, fillvalue=fillvalue) + + num_images = 0 + if not os.path.exists(args.image_dir): + os.makedirs(args.image_dir) + + for names in grouper(args.batchsize, keys, None): + batch = [(name, js[name]) for name in names] + + # May be in time region: (Batch, [Time, Channel]) + org_feats = load_inputs_and_targets(batch)[0] + if transform is not None: + # May be in time-freq region: : (Batch, [Time, Channel, Freq]) + feats = transform(org_feats, train=False) + else: + feats = org_feats + + with torch.no_grad(): + enhanced, mask, ilens = model.enhance(feats) + + for idx, name in enumerate(names): + # Assuming mask, feats : [Batch, Time, Channel. 
Freq] + # enhanced : [Batch, Time, Freq] + enh = enhanced[idx][:ilens[idx]] + mas = mask[idx][:ilens[idx]] + feat = feats[idx] + + # Plot spectrogram + if args.image_dir is not None and num_images < args.num_images: + import matplotlib.pyplot as plt + num_images += 1 + ref_ch = 0 + + plt.figure(figsize=(20, 10)) + plt.subplot(4, 1, 1) + plt.title('Mask [ref={}ch]'.format(ref_ch)) + plot_spectrogram(plt, mas[:, ref_ch].T, fs=args.fs, + mode='linear', frame_shift=frame_shift, + bottom=False, labelbottom=False) + + plt.subplot(4, 1, 2) + plt.title('Noisy speech [ref={}ch]'.format(ref_ch)) + plot_spectrogram(plt, feat[:, ref_ch].T, fs=args.fs, + mode='db', frame_shift=frame_shift, + bottom=False, labelbottom=False) + + plt.subplot(4, 1, 3) + plt.title('Masked speech [ref={}ch]'.format(ref_ch)) + plot_spectrogram( + plt, (feat[:, ref_ch] * mas[:, ref_ch]).T, + frame_shift=frame_shift, + fs=args.fs, mode='db', bottom=False, labelbottom=False) + + plt.subplot(4, 1, 4) + plt.title('Enhanced speech') + plot_spectrogram(plt, enh.T, fs=args.fs, + mode='db', frame_shift=frame_shift) + + plt.savefig(os.path.join(args.image_dir, name + '.png')) + plt.clf() + + # Write enhanced wave files + if enh_writer is not None: + if istft is not None: + enh = istft(enh) + else: + enh = enh + + if args.keep_length: + if len(org_feats[idx]) < len(enh): + # Truncate the frames added by stft padding + enh = enh[:len(org_feats[idx])] + elif len(org_feats) > len(enh): + padwidth = [(0, (len(org_feats[idx]) - len(enh)))] \ + + [(0, 0)] * (enh.ndim - 1) + enh = np.pad(enh, padwidth, mode='constant') + + if args.enh_filetype in ('sound', 'sound.hdf5'): + enh_writer[name] = (args.fs, enh) + else: + # Hint: To dump stft_signal, mask or etc, + # enh_filetype='hdf5' might be convenient. 
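                    # For filetypes other than sound/sound.hdf5 the enhanced
                    # matrix is written as-is under the utterance key, which is
                    # how STFT-domain signals or masks can be dumped per the
                    # hint above.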
+ enh_writer[name] = enh + + if num_images >= args.num_images and enh_writer is None: + logging.info('Breaking the process.') + break diff --git a/espnet/bin/asr_rnnt_recog.py b/espnet/bin/asr_rnnt_recog.py new file mode 100755 index 00000000000..854a9126307 --- /dev/null +++ b/espnet/bin/asr_rnnt_recog.py @@ -0,0 +1,130 @@ +#!/usr/bin/env python +# encoding: utf-8 + +import configargparse +import logging +import os +import random +import sys + +import numpy as np + +from espnet.utils.cli_utils import strtobool + +def get_parser(): + parser = configargparse.ArgumentParser( + description='Transcribe text from speech using a speech recognition model on one CPU or GPU', + config_file_parser_class=configargparse.YAMLConfigFileParser, + formatter_class=configargparse.ArgumentDefaultsHelpFormatter) + # general configuration + parser.add('--config', is_config_file=True, + help='Config file path') + parser.add('--config2', is_config_file=True, + help='Second config file path that overwrites the settings in `--config`') + parser.add('--config3', is_config_file=True, + help='Third config file path that overwrites the settings in `--config` and `--config2`') + parser.add_argument('--ngpu', default=0, type=int, + help='Number of GPUs') + parser.add_argument('--backend', default='pytorch', type=str, + choices=['pytorch'], + help='Backend library') + parser.add_argument('--debugmode', default=1, type=int, + help='Debugmode') + parser.add_argument('--seed', default=1, type=int, + help='Random seed') + parser.add_argument('--verbose', '-V', default=1, type=int, + help='Verbose option') + parser.add_argument('--batchsize', default=0, type=int, + help='Batch size for beam search (0: means no batch processing)') + parser.add_argument('--preprocess-conf', type=str, default=None, + help='The configuration file for the pre-processing') + # task related + parser.add_argument('--recog-json', type=str, + help='Filename of recognition data (json)') + parser.add_argument('--result-label', type=str, required=True, + help='Filename of result label data (json)') + # model (parameter) related + parser.add_argument('--model', type=str, required=True, + help='Model file parameters to read') + parser.add_argument('--model-conf', type=str, default=None, + help='Model config file') + # search related + parser.add_argument('--search-type', type=str, default='beam', + choices=['greedy', 'beam'], + help='Search algorithm to use.') + parser.add_argument('--nbest', type=int, default=1, + help='Output N-best hypotheses') + parser.add_argument('--beam-size', type=int, default=1, + help='Beam size') + parser.add_argument('--score-norm', type=strtobool, nargs='?', + default=True, + help='Length score normalization for beam search') + # rnnlm related + parser.add_argument('--rnnlm', type=str, default=None, + help='RNNLM model file to read') + parser.add_argument('--rnnlm-conf', type=str, default=None, + help='RNNLM model config file to read') + parser.add_argument('--word-rnnlm', type=str, default=None, + help='Word RNNLM model file to read') + parser.add_argument('--word-rnnlm-conf', type=str, default=None, + help='Word RNNLM model config file to read') + parser.add_argument('--word-dict', type=str, default=None, + help='Word list to read') + parser.add_argument('--lm-weight', default=0.1, type=float, + help='RNNLM weight.') + + return parser + +def main(args): + parser = get_parser() + args = parser.parse_args(args) + + # logging info + if args.verbose == 1: + logging.basicConfig( + level=logging.INFO, format="%(asctime)s 
(%(module)s:%(lineno)d) %(levelname)s: %(message)s") + elif args.verbose == 2: + logging.basicConfig(level=logging.DEBUG, + format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s") + else: + logging.basicConfig( + level=logging.WARN, format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s") + logging.warning("Skip DEBUG/INFO messages") + + # check CUDA_VISIBLE_DEVICES + if args.ngpu > 0: + cvd = os.environ.get("CUDA_VISIBLE_DEVICES") + if cvd is None: + logging.warning("CUDA_VISIBLE_DEVICES is not set.") + elif args.ngpu != len(cvd.split(",")): + logging.error("#gpus is not matched with CUDA_VISIBLE_DEVICES.") + sys.exit(1) + + # TODO(mn5k): support of multiple GPUs + if args.ngpu > 1: + logging.error("The program only supports ngpu=1.") + sys.exit(1) + + # display PYTHONPATH + logging.info('python path = ' + os.environ.get('PYTHONPATH', '(None)')) + + # seed setting + random.seed(args.seed) + np.random.seed(args.seed) + logging.info('set random seed = %d' % args.seed) + + # validate rnn options + if args.rnnlm is not None and args.word_rnnlm is not None: + logging.error("It seems that both --rnnlm and --word-rnnlm are specified. Please use either option.") + sys.exit(1) + + # recog + logging.info('backend = ' + args.backend) + if args.backend == "pytorch": + from espnet.asr.pytorch_backend.asr_rnnt import recog + recog(args) + else: + raise ValueError("Only pytorch is supported for RNN-Transducer.") + +if __name__ == '__main__': + main(sys.argv[1:]) diff --git a/espnet/bin/asr_rnnt_train.py b/espnet/bin/asr_rnnt_train.py new file mode 100755 index 00000000000..5e5eccb7437 --- /dev/null +++ b/espnet/bin/asr_rnnt_train.py @@ -0,0 +1,348 @@ +#!/usr/bin/env python +# encoding: utf-8 + +import configargparse +import logging +import os +import random +import subprocess +import sys + +import numpy as np + +from espnet.utils.cli_utils import strtobool +from espnet.utils.training.batchfy import BATCH_COUNT_CHOICES + +def get_parser(): + parser = configargparse.ArgumentParser( + description="Train an automatic speech recognition (ASR) model on one CPU, one or multiple GPUs", + config_file_parser_class=configargparse.YAMLConfigFileParser, + formatter_class=configargparse.ArgumentDefaultsHelpFormatter) + + # general configuration + parser.add('--config', is_config_file=True, help='config file path') + parser.add('--config2', is_config_file=True, + help='second config file path that overwrites the settings in `--config`.') + parser.add('--config3', is_config_file=True, + help='third config file path that overwrites the settings in `--config` and `--config2`.') + parser.add_argument('--ngpu', default=0, type=int, + help='Number of GPUs') + parser.add_argument('--backend', default='pytorch', type=str, + choices=['pytorch'], + help='Backend library') + parser.add_argument('--outdir', type=str, required=True, + help='Output directory') + parser.add_argument('--debugmode', default=1, type=int, + help='Debugmode') + parser.add_argument('--dict', required=True, + help='Dictionary') + parser.add_argument('--seed', default=1, type=int, + help='Random seed') + parser.add_argument('--debugdir', type=str, + help='Output directory for debugging') + parser.add_argument('--resume', type=str, nargs='?', + help='Resume the training from snapshot') + parser.add_argument('--minibatches', '-N', type=int, default='-1', + help='Process only N minibatches (for debug)') + parser.add_argument('--verbose', '-V', default=0, type=int, + help='Verbose option') + 
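    # For orientation, a typical invocation of this parser (all paths are
    # hypothetical placeholders following the voxforge recipe layout):
    #   asr_rnnt_train.py --config conf/tuning/train_rnnt.yaml --ngpu 1 \
    #       --outdir exp/tr_it_train_rnnt/results \
    #       --dict data/lang_1char/tr_it_units.txt \
    #       --train-json dump/tr_it/deltafalse/data.json \
    #       --valid-json dump/dt_it/deltafalse/data.json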
parser.add_argument('--tensorboard-dir', default=None, type=str, nargs='?', help="Tensorboard log dir path") + # task related + parser.add_argument('--train-json', type=str, default=None, + help='Filename of train label data (json)') + parser.add_argument('--valid-json', type=str, default=None, + help='Filename of validation label data (json)') + # network architecture + parser.add_argument('--model-module', type=str, default=None, + help='model defined module (default: espnet.nets.xxx_backend.e2e_asr)') + # encoder + parser.add_argument('--etype', default='blstmp', type=str, + choices=['lstm', 'blstm', 'lstmp', 'blstmp', 'vgglstmp', 'vggblstmp', 'vgglstm', 'vggblstm', + 'gru', 'bgru', 'grup', 'bgrup', 'vgggrup', 'vggbgrup', 'vgggru', 'vggbgru'], + help='Type of encoder network architecture') + parser.add_argument('--elayers', default=4, type=int, + help='Number of encoder layers (for shared recognition part in multi-speaker asr mode)') + parser.add_argument('--eunits', '-u', default=300, type=int, + help='Number of encoder hidden units') + parser.add_argument('--eprojs', default=320, type=int, + help='Number of encoder projection units') + parser.add_argument('--subsample', default=1, type=str, + help='Subsample input frames x_y_z means subsample every x frame at 1st layer, ' + 'every y frame at 2nd layer etc.') + # decoder related + parser.add_argument('--dtype', default='lstm', type=str, + choices=['lstm', 'gru'], + help='Type of decoder network architecture') + parser.add_argument('--dlayers', default=1, type=int, + help='Number of decoder layers') + parser.add_argument('--dunits', default=320, type=int, + help='Number of decoder hidden units') + parser.add_argument('--dec-embed-dim', default=320, type=int, + help='Number of decoder embeddings dimensions') + # attention related + parser.add_argument('--atype', default='dot', type=str, + choices=['noatt', 'dot', 'add', 'location', 'coverage', + 'coverage_location', 'location2d', 'location_recurrent', + 'multi_head_dot', 'multi_head_add', 'multi_head_loc', + 'multi_head_multi_res_loc'], + help='Type of attention architecture') + parser.add_argument('--adim', default=320, type=int, + help='Number of attention transformation dimensions') + parser.add_argument('--awin', default=5, type=int, + help='Window size for location2d attention') + parser.add_argument('--aheads', default=4, type=int, + help='Number of heads for multi head attention') + parser.add_argument('--aconv-chans', default=-1, type=int, + help='Number of attention convolution channels \ + (negative value indicates no location-aware attention)') + parser.add_argument('--aconv-filts', default=100, type=int, + help='Number of attention convolution filters \ + (negative value indicates no location-aware attention)') + parser.add_argument('--spa', action='store_true', + help='Enable speaker parallel attention.') + # loss + parser.add_argument('--rnnt_type', default='warp-transducer', type=str, + choices=['warp-transducer'], + help='Type of RNN Transducer implementation to calculate loss.') + # recognition options to compute CER/WER + parser.add_argument('--report-cer', default=False, action='store_true', + help='Compute CER on development set') + parser.add_argument('--report-wer', default=False, action='store_true', + help='Compute WER on development set') + parser.add_argument('--nbest', type=int, default=1, + help='Output N-best hypotheses') + parser.add_argument('--beam-size', type=int, default=4, + help='Beam size') + parser.add_argument('--rnnlm', type=str, default=None, + 
help='RNNLM model file to read') + parser.add_argument('--rnnlm-conf', type=str, default=None, + help='RNNLM model config file to read') + parser.add_argument('--lm-weight', default=0.1, type=float, + help='RNNLM weight.') + parser.add_argument('--sym-space', default='', type=str, + help='Space symbol') + parser.add_argument('--sym-blank', default='', type=str, + help='Blank symbol') + # model related + parser.add_argument('--rnnt-mode', default=0, type=int, choices=[0, 1], + help='RNN-Transducing mode (0:rnnt, 1:rnnt-att)') + parser.add_argument('--joint-dim', default=320, type=int, + help='Number of dimensions in joint space') + # model (parameter) related + parser.add_argument('--dropout-rate', default=0.0, type=float, + help='Dropout rate for the encoder') + parser.add_argument('--dropout-rate-decoder', default=0.0, type=float, + help='Dropout rate for the decoder') + parser.add_argument('--dropout-rate-embed-decoder', default=0.0, type=float, + help='Dropout rate for the decoder embeddings') + # minibatch related + parser.add_argument('--sortagrad', default=0, type=int, nargs='?', + help="How many epochs to use sortagrad for. 0 = deactivated, -1 = all epochs") + parser.add_argument('--batch-count', default='auto', choices=BATCH_COUNT_CHOICES, + help='How to count batch_size. The default (auto) will find how to count by args.') + parser.add_argument('--batch-size', '-b', default=50, type=int, + help='Batch size') + parser.add_argument('--batch-bins', default=0, type=int, + help='Maximum bins in a minibatch (0 to disable)') + parser.add_argument('--batch-frames-in', default=0, type=int, + help='Maximum input frames in a minibatch (0 to disable)') + parser.add_argument('--batch-frames-out', default=0, type=int, + help='Maximum output frames in a minibatch (0 to disable)') + parser.add_argument('--batch-frames-inout', default=0, type=int, + help='Maximum input+output frames in a minibatch (0 to disable)') + parser.add_argument('--maxlen-in', default=800, type=int, metavar='ML', + help='Batch size is reduced if the input sequence length > ML') + parser.add_argument('--maxlen-out', default=150, type=int, metavar='ML', + help='Batch size is reduced if the output sequence length > ML') + parser.add_argument('--n_iter_processes', default=0, type=int, + help='Number of processes of iterator') + parser.add_argument('--preprocess-conf', type=str, default=None, + help='The configuration file for the pre-processing') + # optimization related + parser.add_argument('--opt', default='adadelta', type=str, + choices=['adadelta', 'adam'], + help='Optimizer') + parser.add_argument('--accum-grad', default=1, type=int, + help='Number of gradient accumulation') + parser.add_argument('--eps', default=1e-8, type=float, + help='Epsilon constant for optimizer') + parser.add_argument('--eps-decay', default=0.01, type=float, + help='Decaying ratio of epsilon') + parser.add_argument('--weight-decay', default=0.0, type=float, + help='Weight decay ratio') + parser.add_argument('--criterion', default='loss', type=str, + choices=['loss'], + help='Criterion to perform epsilon decay') + parser.add_argument('--threshold', default=1e-4, type=float, + help='Threshold to stop iteration') + parser.add_argument('--epochs', '-e', default=30, type=int, + help='Maximum number of epochs') + parser.add_argument('--early-stop-criterion', default='validation/main/loss', type=str, nargs='?', + help="Value to monitor to trigger an early stopping of the training") + parser.add_argument('--patience', default=3, type=int, nargs='?', + 
help="Number of epochs to wait without improvement before stopping the training") + parser.add_argument('--grad-clip', default=5, type=float, + help='Gradient norm threshold to clip') + parser.add_argument('--num-save-attention', default=3, type=int, + help='Number of samples of attention to be saved') + parser.add_argument('--grad-noise', type=strtobool, default=False, + help='The flag to switch to use noise injection to gradients during training') + # finetuning related + parser.add_argument('--enc-init', default=None, type=str, + help='Initialize encoder model part from pre-trained ESPNET ASR model.') + parser.add_argument('--enc-init-mods', default='enc.enc.', + type=lambda s: [str(mod) for mod in s.split(',') if s != ''], + help='List of encoder modules to initialize, separated by a comma.') + parser.add_argument('--dec-init', default=None, type=str, + help='Initialize decoder model part from pre-trained ESPNET ASR or LM model.') + parser.add_argument('--dec-init-mods', default='att.,dec.decoder.,dec.att.,dec.embed.', + type=lambda s: [str(mod) for mod in s.split(',') if s != ''], + help='List of decoder modules to initialize, separated by a comma.') + parser.add_argument('--freeze-modules', default='', + type=lambda s: [str(mod) for mod in s.split(',') if s != ''], + help='List of modules to freeze, separated by a comma.') + # speech translation related + parser.add_argument('--use-frontend', type=strtobool, default=False, + help='The flag to switch to use frontend system.') + # WPE related + parser.add_argument('--use-wpe', type=strtobool, default=False, + help='Apply Weighted Prediction Error') + parser.add_argument('--wtype', default='blstmp', type=str, + choices=['lstm', 'blstm', 'lstmp', 'blstmp', 'vgglstmp', 'vggblstmp', 'vgglstm', 'vggblstm', 'gru', 'bgru', 'grup', 'bgrup', 'vgggrup', 'vggbgrup', 'vgggru', 'vggbgru'], + help='Type of encoder network architecture of the mask estimator for WPE.') + parser.add_argument('--wlayers', type=int, default=2, + help='') + parser.add_argument('--wunits', type=int, default=300, + help='') + parser.add_argument('--wprojs', type=int, default=300, + help='') + parser.add_argument('--wdropout-rate', type=float, default=0.0, + help='') + parser.add_argument('--wpe-taps', type=int, default=5, + help='') + parser.add_argument('--wpe-delay', type=int, default=3, + help='') + parser.add_argument('--use-dnn-mask-for-wpe', type=strtobool, + default=False, + help='Use DNN to estimate the power spectrogram. ' + 'This option is experimental.') + # Beamformer related + parser.add_argument('--use-beamformer', type=strtobool, + default=True, help='') + parser.add_argument('--btype', default='blstmp', type=str, + choices=['lstm', 'blstm', 'lstmp', 'blstmp', 'vgglstmp', 'vggblstmp', 'vgglstm', 'vggblstm', 'gru', 'bgru', 'grup', 'bgrup', 'vgggrup', 'vggbgrup', 'vgggru', 'vggbgru'], + help='Type of encoder network architecture ' + 'of the mask estimator for Beamformer.') + parser.add_argument('--blayers', type=int, default=2, + help='') + parser.add_argument('--bunits', type=int, default=300, + help='') + parser.add_argument('--bprojs', type=int, default=300, + help='') + parser.add_argument('--badim', type=int, default=320, + help='') + parser.add_argument('--ref-channel', type=int, default=-1, + help='The reference channel used for beamformer. 
' + 'By default, the channel is estimated by DNN.') + parser.add_argument('--bdropout-rate', type=float, default=0.0, + help='') + # Feature transform: Normalization + parser.add_argument('--stats-file', type=str, default=None, + help='The stats file for the feature normalization') + parser.add_argument('--apply-uttmvn', type=strtobool, default=True, + help='Apply utterance level mean ' + 'variance normalization.') + parser.add_argument('--uttmvn-norm-means', type=strtobool, + default=True, help='') + parser.add_argument('--uttmvn-norm-vars', type=strtobool, default=False, + help='') + # Feature transform: Fbank + parser.add_argument('--fbank-fs', type=int, default=16000, + help='The sample frequency used for ' + 'the mel-fbank creation.') + parser.add_argument('--n-mels', type=int, default=80, + help='The number of mel-frequency bins.') + parser.add_argument('--fbank-fmin', type=float, default=0., + help='') + parser.add_argument('--fbank-fmax', type=float, default=None, + help='') + + return parser + +def main(cmd_args): + parser = get_parser() + args, _ = parser.parse_known_args(cmd_args) + + from espnet.utils.dynamic_import import dynamic_import + if args.model_module is not None: + model_class = dynamic_import(args.model_module) + model_class.add_arguments(parser) + args = parser.parse_args(cmd_args) + if args.model_module is None: + args.model_module = "espnet.nets." + args.backend + "_backend.e2e_asr_rnnt:E2E" + if 'pytorch_backend' in args.model_module: + args.backend = 'pytorch' + + # logging info + if args.verbose > 0: + logging.basicConfig( + level=logging.INFO, format='%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s') + else: + logging.basicConfig( + level=logging.WARN, format='%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s') + logging.warning('Skip DEBUG/INFO messages') + + #check CUDA_VISIBLE_DEVICES + # If --ngpu is not given, + # 1. if CUDA_VISIBLE_DEVICES is set, all visible devices + # 2. if nvidia-smi exists, use all devices + # 3. 
else ngpu=0 + if args.ngpu is None: + cvd = os.environ.get("CUDA_VISIBLE_DEVICES") + if cvd is not None: + ngpu = len(cvd.split(',')) + else: + logging.warning("CUDA_VISIBLE_DEVICES is not set.") + try: + p = subprocess.run(['nvidia-smi', '-L'], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE) + except (subprocess.CalledProcessError, FileNotFoundError): + ngpu = 0 + else: + ngpu = len(p.stderr.decode().split('\n')) - 1 + else: + ngpu = args.ngpu + logging.info(f"ngpu: {ngpu}") + + # display PYTHONPATH + logging.info('python path = ' + os.environ.get('PYTHONPATH', '(None)')) + + # set random seed + logging.info('random seed = %d' % args.seed) + random.seed(args.seed) + np.random.seed(args.seed) + + # load dictionary for debug log + if args.dict is not None: + with open(args.dict, 'rb') as f: + dictionary = f.readlines() + char_list = [entry.decode('utf-8').split(' ')[0] + for entry in dictionary] + char_list.insert(0, '') + char_list.append('') + args.char_list = char_list + else: + args.char_list = None + + # train + logging.info('backend = ' + args.backend) + + if args.backend == "pytorch": + from espnet.asr.pytorch_backend.asr_rnnt import train + train(args) + else: + raise ValueError("Only pytorch is supported for RNN-Transducer.") + +if __name__ == '__main__': + main(sys.argv[1:]) diff --git a/espnet/nets/pytorch_backend/e2e_asr_rnnt.py b/espnet/nets/pytorch_backend/e2e_asr_rnnt.py new file mode 100644 index 00000000000..15e22831110 --- /dev/null +++ b/espnet/nets/pytorch_backend/e2e_asr_rnnt.py @@ -0,0 +1,396 @@ +#!/usr/bin/env python + +# Copyright 2017 Johns Hopkins University (Shinji Watanabe) +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +import argparse +import logging +import math + +import editdistance + +import chainer +import numpy as np +import six +import torch + +from itertools import groupby + +from chainer import reporter + +from espnet.nets.asr_interface import ASRInterface +from espnet.nets.pytorch_backend.rnn.encoders import encoder_for +from espnet.nets.pytorch_backend.rnn.attentions import att_for +from espnet.nets.pytorch_backend.rnn.decoders_rnnt import decoder_for + +from espnet.nets.pytorch_backend.nets_utils import pad_list +from espnet.nets.pytorch_backend.nets_utils import to_device +from espnet.nets.pytorch_backend.nets_utils import to_torch_tensor + +class Reporter(chainer.Chain): + """A chainer reporter wrapper""" + + def report(self, loss, cer, wer): + reporter.report({'cer': cer}, self) + reporter.report({'wer': wer}, self) + logging.info('loss:' + str(loss)) + reporter.report({'loss': loss}, self) + +class E2E(ASRInterface, torch.nn.Module): + """E2E module + + Args: + idim (int): dimension of inputs + odim (int): dimension of outputs + args (namespace): argument Namespace containing options + """ + + def __init__(self, idim, odim, args): + super(E2E, self).__init__() + torch.nn.Module.__init__(self) + self.rnnt_mode = args.rnnt_mode + self.etype = args.etype + self.verbose = args.verbose + self.char_list = args.char_list + self.outdir = args.outdir + self.space = args.sym_space + self.blank = args.sym_blank + self.reporter = Reporter() + + # note that eos is the same as sos (equivalent ID) + self.sos = odim - 1 + + # subsample info + # +1 means input (+1) and layers outputs (args.elayer) + subsample = np.ones(args.elayers + 1, dtype=np.int) + if args.etype.endswith("p") and not args.etype.startswith("vgg"): + ss = args.subsample.split("_") + for j in range(min(args.elayers + 1, len(ss))): + subsample[j] = int(ss[j]) + else: + 
logging.warning( + 'Subsampling is not performed for vgg*. It is performed in max pooling layers at CNN.') + logging.info('subsample: ' + ' '.join([str(x) for x in subsample])) + self.subsample = subsample + + if args.use_frontend: + # Relative importing because of using python3 syntax + from espnet.nets.pytorch_backend.frontends.feature_transform \ + import feature_transform_for + from espnet.nets.pytorch_backend.frontends.frontend \ + import frontend_for + + self.frontend = frontend_for(args, idim) + self.feature_transform = feature_transform_for(args, (idim - 1) * 2) + idim = args.n_mels + else: + self.frontend = None + + # encoder + self.enc = encoder_for(args, idim, self.subsample) + + if args.rnnt_mode == 1: + # attention + self.att = att_for(args) + # decoder + self.dec = decoder_for(args, odim, self.sos, self.att) + else: + # prediction + self.dec = decoder_for(args, odim, self.sos) + # weight initialization + self.init_like_chainer() + + # options for beam search + if 'report_cer' in vars(args) and (args.report_cer or args.report_wer): + recog_args = {'beam_size': args.beam_size, 'nbest': args.nbest, + 'space': args.sym_space} + + self.recog_args = argparse.Namespace(**recog_args) + self.report_cer = args.report_cer + self.report_wer = args.report_wer + else: + self.report_cer = False + self.report_wer = False + + self.logzero = -10000000000.0 + self.rnnlm = None + self.loss = None + + def init_like_chainer(self): + """Initialize weight like chainer + + chainer basically uses LeCun way: W ~ Normal(0, fan_in ** -0.5), b = 0 + pytorch basically uses W, b ~ Uniform(-fan_in**-0.5, fan_in**-0.5) + + however, there are two exceptions as far as I know. + - EmbedID.W ~ Normal(0, 1) + - LSTM.upward.b[forget_gate_range] = 1 (but not used in NStepLSTM) + """ + + def lecun_normal_init_parameters(module): + for p in module.parameters(): + data = p.data + if data.dim() == 1: + # bias + data.zero_() + elif data.dim() == 2: + # linear weight + n = data.size(1) + stdv = 1. / math.sqrt(n) + data.normal_(0, stdv) + elif data.dim() == 4: + # conv weight + n = data.size(1) + for k in data.size()[2:]: + n *= k + stdv = 1. / math.sqrt(n) + data.normal_(0, stdv) + else: + raise NotImplementedError + + def set_forget_bias_to_one(bias): + n = bias.size(0) + start, end = n // 4, n // 2 + bias.data[start:end].fill_(1.) + + lecun_normal_init_parameters(self) + + if self.rnnt_mode == 1: + # embed weight ~ Normal(0, 1) + self.dec.embed.weight.data.normal_(0, 1) + # forget-bias = 1.0 + # https://discuss.pytorch.org/t/set-forget-gate-bias-of-lstm/1745 + for l in range(len(self.dec.decoder)): + set_forget_bias_to_one(self.dec.decoder[l].bias_ih) + else: + self.dec.embed.weight.data.normal_(0, 1) + + + def forward(self, xs_pad, ilens, ys_pad): + """E2E forward + + Args: + xs_pad (torch.Tensor): batch of padded input sequences (B, Tmax, idim) + ilens (torch.Tensor): batch of lengths of input sequences (B) + ys_pad (torch.Tensor): batch of padded character id sequence tensor (B, Lmax) + + Returns: + loss (torch.Tensor): transducer loss value + """ + + # 0. Frontend + if self.frontend is not None: + hs_pad, hlens, mask = self.frontend(to_torch_tensor(xs_pad), ilens) + hs_pad, hlens = self.feature_transform(hs_pad, hlens) + else: + hs_pad, hlens = xs_pad, ilens + + # 1. encoder + hs_pad, hlens, _ = self.enc(hs_pad, hlens) + + # 2. decoder + loss = self.dec(hs_pad, hlens, ys_pad) + + # 3. compute cer/wer + ## Note: not recommended outside debugging right now, + ## the training time is hugely impacted. 
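+        ## For reference, when reporting is enabled the block below runs a full
+        ## beam search for every utterance in the minibatch, using the minimal
+        ## Namespace built in __init__ (beam_size, nbest, space), and then
+        ## accumulates corpus-level error rates; roughly (illustrative sketch only):
+        ##     hyp = self.dec.recognize_beam(hs_pad[b], self.recog_args)[0]['yseq'][1:]
+        ##     cer = sum(char_edit_distances) / sum(char_reference_lengths)
+        ##     wer = sum(word_edit_distances) / sum(word_reference_lengths)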
+ if self.training or not (self.report_cer or self.report_wer): + cer, wer = 0.0, 0.0 + else: + word_eds, word_ref_lens, char_eds, char_ref_lens = [], [], [], [] + + batchsize = int(hs_pad.size(0)) + batch_nbest = [] + + for b in six.moves.range(batchsize): + nbest_hyps = self.dec.recognize_beam(hs_pad[b], self.recog_args) + batch_nbest.append(nbest_hyps) + + y_hats = [nbest_hyp[0]['yseq'][1:] for nbest_hyp in batch_nbest] + + for i, y_hat in enumerate(y_hats): + y_true = ys_pad[i] + + seq_hat = [self.char_list[int(idx)] for idx in y_hat] + seq_true = [self.char_list[int(idx)] for idx in y_true if int(idx) != -1] + seq_hat_text = "".join(seq_hat).replace(self.recog_args.space, ' ') + seq_true_text = "".join(seq_true).replace(self.recog_args.space, ' ') + + hyp_words = seq_hat_text.split() + ref_words = seq_true_text.split() + word_eds.append(editdistance.eval(hyp_words, ref_words)) + word_ref_lens.append(len(ref_words)) + + hyp_chars = seq_hat_text.replace(' ', '') + ref_chars = seq_true_text.replace(' ', '') + char_eds.append(editdistance.eval(hyp_chars, ref_chars)) + char_ref_lens.append(len(ref_chars)) + wer = 0.0 if not self.report_wer else float(sum(word_eds)) / sum(word_ref_lens) + cer = 0.0 if not self.report_cer else float(sum(char_eds)) / sum(char_ref_lens) + + self.loss = loss + loss_data = float(self.loss) + + if not math.isnan(loss_data): + self.reporter.report(loss_data, cer, wer) + else: + logging.warning('loss (=%f) is not correct', loss_data) + + return self.loss + + def recognize(self, x, recog_args, char_list, rnnlm=None): + """E2E recognize + + Args: + x (ndarray): input acoustic feature (T, D) + recog_args (namespace): argument Namespace containing options + char_list (list): list of characters + rnnlm (torch.nn.Module): language model module + + Returns: + y (list): n-best decoding results + """ + + prev = self.training + self.eval() + ilens = [x.shape[0]] + + # subsample frame + x = x[::self.subsample[0], :] + h = to_device(self, to_torch_tensor(x).float()) + # make a utt list (1) to use the same interface for encoder + hs = h.contiguous().unsqueeze(0) + + # 0. Frontend + if self.frontend is not None: + enhanced, hlens, mask = self.frontend(hs, ilens) + hs, hlens = self.feature_transform(enhanced, hlens) + else: + hs, hlens = hs, ilens + + # 1. Encoder + h, _, _ = self.enc(hs, hlens) + + # 2. Decoder (pred+joint or att-dec+joint) + if recog_args.search_type == 'greedy': + y = self.dec.recognize(h[0], recog_args) + else: + y = self.dec.recognize_beam(h[0], recog_args, rnnlm) + + if prev: + self.train() + + return y + + def recognize_batch(self, xs, recog_args, char_list, rnnlm=None): + """E2E recognize with batch + + Args: + xs (ndarray): list of input acoustic feature arrays [(T_1, D), (T_2, D), ...] + recog_args (namespace): argument Namespace containing options + char_list (list): list of characters + rnnlm (torch.nn.Module): language model module + + Returns: + y (list): n-best decoding results + """ + + prev = self.training + self.eval() + ilens = np.fromiter((xx.shape[0] for xx in xs), dtype=np.int64) + + # Subsample frame + xs = [xx[::self.subsample[0], :] for xx in xs] + xs = [to_device(self, to_torch_tensor(xx).float()) for xx in xs] + xs_pad = pad_list(xs, 0.0) + + # 0. Frontend + if self.frontend is not None: + enhanced, hlens, mask = self.frontend(xs_pad, ilens) + hs_pad, hlens = self.feature_transform(enhanced, hlens) + else: + hs_pad, hlens = xs_pad, ilens + + # 1. Encoder + hs_pad, hlens, _ = self.enc(hs_pad, hlens) + + # 2. 
Decoder + hlens = torch.tensor(list(map(int, hlens))) + + if recog_args.search_type == 'greedy': + y = self.dec.recognize_batch(hs_pad, hlens, recog_args) + else: + y = self.dec.recognize_beam_batch(hs_pad, hlens, recog_args, rnnlm) + + if prev: + self.train() + + return y + + def enhance(self, xs): + """Forwarding only the frontend stage + + Args: + xs (ndarray): input acoustic feature (T, C, F) + + Returns: + enhanced (ndarray): + mask (torch.Tensor): + ilens (torch.Tensor): batch of lengths of input sequences (B) + """ + + if self.frontend is None: + raise RuntimeError('Frontend does\'t exist') + prev = self.training + self.eval() + ilens = np.fromiter((xx.shape[0] for xx in xs), dtype=np.int64) + + # subsample frame + xs = [xx[::self.subsample[0], :] for xx in xs] + xs = [to_device(self, to_torch_tensor(xx).float()) for xx in xs] + xs_pad = pad_list(xs, 0.0) + enhanced, hlensm, mask = self.frontend(xs_pad, ilens) + + if prev: + self.train() + + return enhanced.cpu().numpy(), mask.cpu().numpy(), ilens + + def calculate_all_attentions(self, xs_pad, ilens, ys_pad): + """E2E attention calculation + + Args: + xs_pad (torch.Tensor): batch of padded input sequences (B, Tmax, idim) + ilens (torch.Tensor): batch of lengths of input sequences (B) + ys_pad (torch.Tensor): batch of padded character id sequence tensor (B, Lmax) + + Returns: + att_ws (ndarray): attention weights with the following shape, + 1) multi-head case => attention weights (B, H, Lmax, Tmax), + 2) other case => attention weights (B, Lmax, Tmax). + """ + + with torch.no_grad(): + # 0. Frontend + if self.frontend is not None: + hs_pad, hlens, mask = self.frontend(to_torch_tensor(xs_pad), ilens) + hs_pad, hlens = self.feature_transform(hs_pad, hlens) + else: + hs_pad, hlens = xs_pad, ilens + + # encoder + hpad, hlens, _ = self.enc(hs_pad, hlens) + + # decoder + att_ws = self.dec.calculate_all_attentions(hpad, hlens, ys_pad) + + return att_ws + + def subsample_frames(self, x): + # subsample frame + x = x[::self.subsample[0], :] + ilen = [x.shape[0]] + h = to_device(self, torch.from_numpy( + np.array(x, dtype=np.float32))) + h.contiguous() + return h, ilen diff --git a/espnet/nets/pytorch_backend/rnn/decoders_rnnt.py b/espnet/nets/pytorch_backend/rnn/decoders_rnnt.py new file mode 100644 index 00000000000..42f699e9571 --- /dev/null +++ b/espnet/nets/pytorch_backend/rnn/decoders_rnnt.py @@ -0,0 +1,635 @@ +from distutils.version import LooseVersion +import logging +import random +import six + +import numpy as np +import torch +import torch.nn.functional as F +import math + +from torch.nn.utils.rnn import pack_padded_sequence +from torch.nn.utils.rnn import pad_packed_sequence + +from espnet.nets.pytorch_backend.rnn.attentions import att_to_numpy + +from espnet.nets.pytorch_backend.nets_utils import pad_list +from espnet.nets.pytorch_backend.nets_utils import to_device + +class DecoderRNNT(torch.nn.Module): + """RNN-T Decoder module + + Args: + eprojs (int): # encoder projection units + odim (int): dimension of outputs + dtype (str): gru or lstm + dlayers (int): # prediction layers + dunits (int): # prediction units + sos (int): start/end of sentence symbol id + joint_dim (int): dimension of joint space + dropout (float): dropout rate + dropout_embed (float): embedding dropout rate + rnnt_type (str): type of rnn-t implementation + """ + + def __init__(self, eprojs, odim, dtype, dlayers, dunits, sos, joint_dim, embed_dim, + dropout=0.0, dropout_embed=0.0, rnnt_type='warp-transducer'): + super(DecoderRNNT, self).__init__() + + 
self.embed = torch.nn.Embedding(odim, embed_dim, + padding_idx=sos) + self.dropout_embed = torch.nn.Dropout(p=dropout_embed) + + if dtype == "lstm": + dec_net = torch.nn.LSTMCell + else: + dec_net = torch.nn.GRUCell + + self.decoder = torch.nn.ModuleList([dec_net(embed_dim, dunits)]) + self.dropout_dec = torch.nn.ModuleList([torch.nn.Dropout(p=dropout)]) + + for _ in six.moves.range(1, dlayers): + self.decoder += [dec_net(dunits, dunits)] + self.dropout_dec += [torch.nn.Dropout(p=dropout)] + + if rnnt_type == 'warp-transducer': + from warprnnt_pytorch import RNNTLoss + + self.rnnt_loss = RNNTLoss(blank=sos) + else: + raise NotImplementedError + + self.lin_enc = torch.nn.Linear(eprojs, joint_dim, bias=True) + self.lin_dec = torch.nn.Linear(dunits, joint_dim, bias=False) + self.lin_out = torch.nn.Linear(joint_dim, odim) + + self.dlayers = dlayers + self.dunits = dunits + self.dtype = dtype + self.joint_dim = joint_dim + self.odim = odim + + self.rnnt_type = rnnt_type + + self.ignore_id = -1 + self.sos = sos + + def zero_state(self, h_pad): + return h_pad.new_zeros(h_pad.size(0), self.dunits) + + def rnn_forward(self, ey, z_list, c_list, z_prev, c_prev): + if self.dtype == "lstm": + z_list[0], c_list[0] = self.decoder[0](ey, (z_prev[0], c_prev[0])) + + for l in six.moves.range(1, self.dlayers): + z_list[l], c_list[l] = self.decoder[l]( + self.dropout_dec[l - 1](z_list[l - 1]), (z_prev[l], c_prev[l])) + else: + z_list[0] = self.decoder[0](ey, z_prev[0]) + + for l in six.moves.range(1, self.dlayers): + z_list[l] = self.decoder[l](self.dropout_dec[l - 1](z_list[l - 1]), z_prev[l]) + + return z_list, c_list + + def joint(self, h_enc, h_dec): + """ Joint computation of z + + Args: + h_enc (torch.Tensor): batch of expanded hidden state (B, T, 1, Henc) + h_dec (torch.Tensor): batch of expanded hidden state (B, 1, U, Hdec) + + Returns: + z (torch.Tensor): output (B, T, U, odim) + """ + + z = torch.tanh(self.lin_enc(h_enc) + self.lin_dec(h_dec)) + z = self.lin_out(z) + + return z + + def forward(self, hs_pad, hlens, ys_pad): + """Decoder forward + + Args: + hs_pad (torch.Tensor): batch of padded hidden state sequences (B, Tmax, D) + hlens (torch.Tensor): batch of lengths of hidden state sequences (B) + ys_pad (torch.Tensor): batch of padded character id sequence tensor (B, Lmax) + + Returns: + loss (float): rnnt loss value + """ + + ys = [y[y != self.ignore_id] for y in ys_pad] + + hlens = list(map(int, hlens)) + + sos = ys[0].new([self.sos]) + ys_in = [torch.cat([sos, y], dim=0) for y in ys] + ys_in_pad = pad_list(ys_in, self.sos) + + olength = ys_in_pad.size(1) + + c_list = [self.zero_state(hs_pad)] + z_list = [self.zero_state(hs_pad)] + for _ in six.moves.range(1, self.dlayers): + c_list.append(self.zero_state(hs_pad)) + z_list.append(self.zero_state(hs_pad)) + + eys = self.dropout_embed(self.embed(ys_in_pad)) + + z_all = [] + for i in six.moves.range(olength): + z_list, c_list = self.rnn_forward(eys[:, i, :], z_list, c_list, + z_list, c_list) + z_all.append(self.dropout_dec[-1](z_list[-1])) + + h_dec = torch.stack(z_all, dim=1) + + h_enc = hs_pad.unsqueeze(2) + h_dec = h_dec.unsqueeze(1) + + z = self.joint(h_enc, h_dec) + y = pad_list(ys, self.sos).type(torch.int32) + + z_len = to_device(self, torch.IntTensor(hlens)) + y_len = to_device(self, torch.IntTensor([y.size(0) for y in ys])) + + loss = to_device(self, self.rnnt_loss(z, y, z_len, y_len)) + + return loss + + def recognize(self, h, recog_args): + """Greedy search implementation + + Args: + h (torch.Tensor): encoder hidden state sequences 
(Tmax, Henc) + recog_args (Namespace): argument Namespace containing options + + Returns: + hyp (list of dicts): 1-best decoding results + """ + + c_list = [self.zero_state(h.unsqueeze(0))] + z_list = [self.zero_state(h.unsqueeze(0))] + for _ in six.moves.range(1, self.dlayers): + c_list.append(self.zero_state(h.unsqueeze(0))) + z_list.append(self.zero_state(h.unsqueeze(0))) + + hyp = {'score': 0.0, 'yseq': [self.sos]} + + ey = torch.zeros((1, self.dunits)) + z_list, c_list = self.rnn_forward(ey, z_list, c_list, z_list, c_list) + y = self.dropout_dec[-1](z_list[-1]) + + for hi in h: + ytu = F.log_softmax(self.joint(hi, y[0]), dim=0) + logp, pred = torch.max(ytu, dim=0) + + if pred != self.sos: + hyp['yseq'].append(int(pred)) + hyp['score'] += float(logp) + + eys = torch.full((1,1), hyp['yseq'][-1], dtype=torch.long) + ey = self.dropout_embed(self.embed(eys)) + + z_list, c_list = self.rnn_forward(ey[0], z_list, c_list, z_list, c_list) + y = self.dropout_dec[-1](z_list[-1]) + + return [hyp] + + def recognize_beam(self, h, recog_args, rnnlm=None): + """Beam search implementation + + Args: + h (torch.Tensor): encoder hidden state sequences (Tmax, Henc) + recog_args (Namespace): argument Namespace containing options + rnnlm (torch.nn.Module): language module + + Returns: + nbest_hyps (list of dicts): n-best decoding results + """ + + beam = recog_args.beam_size + normscore = recog_args.score_norm + k_range = min(beam, self.odim) + nbest = recog_args.nbest + + c_list = [self.zero_state(h.unsqueeze(0))] + z_list = [self.zero_state(h.unsqueeze(0))] + for _ in six.moves.range(1, self.dlayers): + c_list.append(self.zero_state(h.unsqueeze(0))) + z_list.append(self.zero_state(h.unsqueeze(0))) + + if rnnlm: + final_hyps = [{'score': 0.0, 'yseq': [self.sos], 'z_prev': z_list, 'c_prev': c_list, + 'lm_state': None}] + else: + final_hyps = [{'score': 0.0, 'yseq': [self.sos], 'z_prev': z_list, 'c_prev': c_list}] + + for hi in h: + hyps = final_hyps + final_hyps = [] + + while True: + new_hyp = max(hyps, key=lambda x: x['score']) + hyps.remove(new_hyp) + + vy = to_device(self, torch.full((1,1), new_hyp['yseq'][-1], dtype=torch.long)) + ey = self.dropout_embed(self.embed(vy))[0] + + z_list, c_list = self.rnn_forward(ey, z_list, c_list, + new_hyp['z_prev'], new_hyp['c_prev']) + y = self.dropout_dec[-1](z_list[-1])[0] + + ytu = F.log_softmax(self.joint(hi, y), dim=0) + + if rnnlm: + rnnlm_state, rnnlm_scores = rnnlm.predict(new_hyp['lm_state'], vy[0]) + #ytu += recog_args.lm_weight * rnnlm_scores[0] + + logp, pred = torch.topk(ytu, k=k_range, dim=0) + + for k in six.moves.range(k_range): + beam_hyp = {'score': new_hyp['score'] + logp[k], + 'yseq': new_hyp['yseq'][:], + 'z_prev': new_hyp['z_prev'][:], + 'c_prev': new_hyp['c_prev'][:]} + if rnnlm: + beam_hyp['lm_state'] = new_hyp['lm_state'] + + if pred[k] == self.sos: + final_hyps.append(beam_hyp) + else: + beam_hyp['z_prev'] = z_list[:] + beam_hyp['c_prev'] = c_list[:] + beam_hyp['yseq'].append(pred[k]) + + if rnnlm: + beam_hyp['lm_state'] = rnnlm_state + beam_hyp['score'] += recog_args.lm_weight * rnnlm_scores[0][pred[k]] + + hyps.append(beam_hyp) + + if len(final_hyps) >= k_range: + break + + if normscore: + nbest_hyps = sorted( + final_hyps, key=lambda x: x['score'] / len(x['yseq']), reverse=True)[:nbest] + else: + nbest_hyps = sorted( + final_hyps, key=lambda x: x['score'], reverse=True)[:nbest] + + return nbest_hyps + +class DecoderRNNTAtt(torch.nn.Module): + """RNNT-Att Decoder module + + Args: + eprojs (int): # encoder projection units + odim (int): 
dimension of outputs + dtype (str): gru or lstm + dlayers (int): # decoder layers + dunits (int): # decoder units + sos (int): start of sequence symbol id + att (torch.nn.Module): attention module + joint_dim (int): dimension of joint space + dropout (float): dropout rate + dropout_embed (float): embedding dropout rate + rnnt_type (str): type of rnnt implementation + """ + + def __init__(self, eprojs, odim, dtype, dlayers, dunits, sos, att, joint_dim, + dropout=0.0, dropout_embed=0.0, rnnt_type='warp-transducer'): + super(DecoderRNNTAtt, self).__init__() + + self.embed = torch.nn.Embedding(odim, dunits, padding_idx=sos) + self.dropout_emb = torch.nn.Dropout(p=dropout_embed) + + if dtype == "lstm": + dec_net = torch.nn.LSTMCell + else: + dec_net = torch.nn.GRUCell + + self.decoder = torch.nn.ModuleList([dec_net((dunits + eprojs), dunits)]) + self.dropout_dec = torch.nn.ModuleList([torch.nn.Dropout(p=dropout)]) + + for _ in six.moves.range(1, dlayers): + self.decoder += [dec_net(dunits, dunits)] + self.dropout_dec += [torch.nn.Dropout(p=dropout)] + + if rnnt_type == 'warp-transducer': + from warprnnt_pytorch import RNNTLoss + + self.rnnt_loss = RNNTLoss(blank=sos) + else: + raise NotImplementedError + + self.lin_enc = torch.nn.Linear(eprojs, joint_dim, bias=True) + self.lin_dec = torch.nn.Linear(dunits, joint_dim, bias=False) + self.lin_out = torch.nn.Linear(joint_dim, odim) + + self.att = att + + self.dtype = dtype + self.dlayers = dlayers + self.dunits = dunits + self.joint_dim = joint_dim + self.odim = odim + + self.rnnt_type = rnnt_type + + self.ignore_id = -1 + self.sos = sos + + def zero_state(self, h_pad): + return h_pad.new_zeros(h_pad.size(0), self.dunits) + + def rnn_forward(self, ey, z_list, c_list, z_prev, c_prev): + if self.dtype == "lstm": + z_list[0], c_list[0] = self.decoder[0](ey, (z_prev[0], c_prev[0])) + + for l in six.moves.range(1, self.dlayers): + z_list[l], c_list[l] = self.decoder[l]( + self.dropout_dec[l - 1](z_list[l - 1]), (z_prev[l], c_prev[l])) + else: + z_list[0] = self.decoder[0](ey, z_prev[0]) + + for l in six.moves.range(1, self.dlayers): + z_list[l] = self.decoder[l](self.dropout_dec[l - 1](z_list[l - 1]), z_prev[l]) + return z_list, c_list + + def joint(self, h_enc, h_dec): + """Joint computation of z + + Args: + h_enc (torch.Tensor): batch of expanded hidden state (B, T, 1, Henc) + h_dec (torch.Tensor): batch of expanded hidden state (B, 1, U, Hdec) + + Returns: + z (torch.Tensor): output (B, T, U, odim) + """ + + z = torch.tanh(self.lin_enc(h_enc) + self.lin_dec(h_dec)) + z = self.lin_out(z) + + return z + + def forward(self, hs_pad, hlens, ys_pad): + """Decoder forward + + Args: + hs_pad (torch.Tensor): batch of padded hidden state sequences (B, Tmax, D) + hlens (torch.Tensor): batch of lengths of hidden state sequences (B) + ys_pad (torch.Tensor): batch of padded character id sequence tensor (B, Lmax) + + Returns: + loss (torch.Tensor): rnnt-att loss value + """ + + ys = [y[y != self.ignore_id] for y in ys_pad] + + hlens = list(map(int, hlens)) + + sos = ys[0].new([self.sos]) + ys_in = [torch.cat([sos, y], dim=0) for y in ys] + ys_in_pad = pad_list(ys_in, self.sos) + + olength = ys_in_pad.size(1) + + c_list = [self.zero_state(hs_pad)] + z_list = [self.zero_state(hs_pad)] + for _ in six.moves.range(1, self.dlayers): + c_list.append(self.zero_state(hs_pad)) + z_list.append(self.zero_state(hs_pad)) + + att_w = None + self.att[0].reset() + + eys = self.dropout_emb(self.embed(ys_in_pad)) + + z_all = [] + for i in six.moves.range(olength): + att_c, att_w = 
self.att[0](hs_pad, hlens, self.dropout_dec[0](z_list[0]), att_w) + + ey = torch.cat((eys[:, i, :], att_c), dim=1) + z_list, c_list = self.rnn_forward(ey, z_list, c_list, z_list, c_list) + + z_all.append(self.dropout_dec[-1](z_list[-1])) + + h_dec = torch.stack(z_all, dim=1) + + h_enc = hs_pad.unsqueeze(2) + h_dec = h_dec.unsqueeze(1) + + z = self.joint(h_enc, h_dec) + y = pad_list(ys, self.sos).type(torch.int32) + + z_len = to_device(self, torch.IntTensor(hlens)) + y_len = to_device(self, torch.IntTensor([y.size(0) for y in ys])) + + loss = to_device(self, self.rnnt_loss(z, y, z_len, y_len)) + + return loss + + def recognize(self, h, recog_args): + """Greedy search implementation + + Args: + h (torch.Tensor): encoder hidden state sequences (Tmax, Henc) + recog_args (Namespace): argument Namespace containing options + + Returns: + hyp (list of dicts): 1-best decoding results + """ + + self.att[0].reset() + + c_list = [self.zero_state(h.unsqueeze(0))] + z_list = [self.zero_state(h.unsqueeze(0))] + + for _ in six.moves.range(1, self.dlayers): + c_list.append(self.zero_state(h.unsqueeze(0))) + z_list.append(self.zero_state(h.unsqueeze(0))) + + hyp = {'score': 0.0, 'yseq': [self.sos]} + + eys = torch.zeros((1, self.dunits)) + att_c, att_w = self.att[0](h.unsqueeze(0), [h.size(0)], + self.dropout_dec[0](z_list[0]), None) + ey = torch.cat((eys, att_c), dim=1) + + z_list, c_list = self.rnn_forward(ey, z_list, c_list, z_list, c_list) + y = self.dropout_dec[-1](z_list[-1]) + + for hi in h: + ytu = F.log_softmax(self.joint(hi, y[0]), dim=0) + logp, pred = torch.max(ytu, dim=0) + + if pred != self.sos: + hyp['yseq'].append(int(pred)) + hyp['score'] += float(logp) + + eys = torch.full((1,1), hyp['yseq'][-1], dtype=torch.long) + ey = self.dropout_emb(self.embed(eys)) + att_c, att_w = self.att[0](h.unsqueeze(0), [h.size(0)], + self.dropout_dec[0](z_list[0]), + att_w) + ey = torch.cat((ey[0], att_c), dim=1) + + z_list, c_list = self.rnn_forward(ey, z_list, c_list, z_list, c_list) + y = self.dropout_dec[-1](z_list[-1]) + + return [hyp] + + def recognize_beam(self, h, recog_args, rnnlm=None): + """Beam search recognition + + Args: + h (torch.Tensor): encoder hidden state sequences (Tmax, Henc) + recog_args (Namespace): argument Namespace containing options + rnnlm (torch.nn.Module): language module + + Results: + nbest_hyps (list of dicts): n-best decoding results + """ + + beam = recog_args.beam_size + normscore = recog_args.score_norm + k_range = min(beam, self.odim) + nbest = recog_args.nbest + + self.att[0].reset() + + c_list = [self.zero_state(h.unsqueeze(0))] + z_list = [self.zero_state(h.unsqueeze(0))] + for _ in six.moves.range(1, self.dlayers): + c_list.append(self.zero_state(h.unsqueeze(0))) + z_list.append(self.zero_state(h.unsqueeze(0))) + + if rnnlm: + final_hyps = [{'score': 0.0, 'yseq': [self.sos], 'z_prev': z_list, 'c_prev': c_list, + 'a_prev': None, 'lm_state': None}] + else: + final_hyps = [{'score': 0.0, 'yseq': [self.sos], 'z_prev': z_list, 'c_prev': c_list, + 'a_prev': None}] + for hi in h: + hyps = final_hyps + final_hyps = [] + + while True: + new_hyp = max(hyps, key=lambda x: x['score']) + hyps.remove(new_hyp) + + vy = to_device(self, torch.full((1,1), new_hyp['yseq'][-1], dtype=torch.long)) + ey = self.dropout_emb(self.embed(vy)) + + att_c, att_w = self.att[0](h.unsqueeze(0), [h.size(0)], + self.dropout_dec[0](new_hyp['z_prev'][0]), + new_hyp['a_prev']) + + ey = torch.cat((ey[0], att_c), dim=1) + z_list, c_list = self.rnn_forward(ey, z_list, c_list, + new_hyp['z_prev'], 
new_hyp['c_prev']) + y = self.dropout_dec[-1](z_list[-1]) + ytu = F.log_softmax(self.joint(hi, y[0]), dim=0) + + if rnnlm: + rnnlm_state, rnnlm_scores = rnnlm.predict(new_hyp['lm_state'], vy[0]) + + logp, pred = torch.topk(ytu, k=k_range, dim=0) + + for k in six.moves.range(k_range): + beam_hyp = {'score': new_hyp['score'] + logp[k], + 'yseq': new_hyp['yseq'][:], + 'z_prev': new_hyp['z_prev'][:], + 'c_prev': new_hyp['c_prev'][:], + 'a_prev': new_hyp['a_prev']} + if rnnlm: + beam_hyp['lm_state'] = new_hyp['lm_state'] + + if pred[k] == self.sos: + final_hyps.append(beam_hyp) + else: + beam_hyp['z_prev'] = z_list[:] + beam_hyp['c_prev'] = c_list[:] + beam_hyp['a_prev'] = att_w[:] + beam_hyp['yseq'].append(pred[k]) + + if rnnlm: + beam_hyp['lm_state'] = rnnlm_state + beam_hyp['score'] += recog_args.lm_weight * rnnlm_scores[0][pred[k]] + + hyps.append(beam_hyp) + + if len(final_hyps) >= k_range: + break + + if normscore: + nbest_hyps = sorted( + final_hyps, key=lambda x: x['score'] / len(x['yseq']), reverse=True)[:nbest] + else: + nbest_hyps = sorted( + final_hyps, key=lambda x: x['score'], reverse=True)[:nbest] + + return nbest_hyps + + def calculate_all_attentions(self, hs_pad, hlen, ys_pad): + """Calculate all of attentions + + Args: + hs_pad (torch.Tensor): batch of padded hidden state sequences (B, Tmax, D) + hlen (torch.Tensor): batch of lengths of hidden state sequences (B) + ys_pad (torch.Tensor): batch of padded character id sequence tensor (B, Lmax) + + Returns: + att_ws (ndarray): attention weights with the following shape, + 1) multi-head case => attention weights (B, H, Lmax, Tmax), + 2) other case => attention weights (B, Lmax, Tmax). + """ + + ys = [y[y != self.ignore_id] for y in ys_pad] + hlen = list(map(int, hlen)) + + sos = ys[0].new([self.sos]) + + ys_in = [torch.cat([sos, y], dim=0) for y in ys] + ys_in_pad = pad_list(ys_in, self.sos) + + olength = ys_in_pad.size(1) + + c_list = [self.zero_state(hs_pad)] + z_list = [self.zero_state(hs_pad)] + for _ in six.moves.range(1, self.dlayers): + c_list.append(self.zero_state(hs_pad)) + z_list.append(self.zero_state(hs_pad)) + + att_w = None + att_ws = [] + self.att[0].reset() + + eys = self.dropout_emb(self.embed(ys_in_pad)) + + for i in six.moves.range(olength): + att_c, att_w = self.att[0](hs_pad, hlen, self.dropout_dec[0](z_list[0]), att_w) + ey = torch.cat((eys[:, i, :], att_c), dim=1) + z_list, c_list = self.rnn_forward(ey, z_list, c_list, z_list, c_list) + + att_ws.append(att_w) + + att_ws = att_to_numpy(att_ws, self.att[0]) + + return att_ws + +def decoder_for(args, odim, sos, att=None): + if args.rnnt_mode == 0: + return DecoderRNNT(args.eprojs, odim, args.dtype, args.dlayers, args.dunits, + sos, args.joint_dim, args.dec_embed_dim, + args.dropout_rate_decoder, args.dropout_rate_embed_decoder, + args.rnnt_type) + elif args.rnnt_mode == 1: + return DecoderRNNTAtt(args.eprojs, odim, args.dtype, args.dlayers, args.dunits, + sos, att, args.joint_dim, + args.dropout_rate_decoder, args.dropout_rate_embed_decoder, + args.rnnt_type) diff --git a/tools/Makefile b/tools/Makefile index 311451a04db..bc2ba576c2c 100644 --- a/tools/Makefile +++ b/tools/Makefile @@ -30,7 +30,7 @@ endif all: kaldi.done python check_install -python: venv $(CUDA_DEPS) warp-ctc.done chainer_ctc.done +python: venv $(CUDA_DEPS) warp-ctc.done warp-transducer.done chainer_ctc.done extra: nkf.done sentencepiece.done mecab.done moses.done mwerSegmenter.done pesq @@ -97,6 +97,17 @@ warp-ctc.done: espnet.done . 
venv/bin/activate; cd warp-ctc/pytorch_binding && python setup.py install # maybe need to: apt-get install python-dev
 	touch warp-ctc.done
 
+warp-transducer.done: espnet.done
+	rm -rf warp-transducer
+	git clone https://github.com/HawkAaron/warp-transducer.git
+	# Note: building the extension against pytorch>=1.0 requires gcc>=4.9
+	if . venv/bin/activate && python -c 'import torch as t;assert t.__version__[0] == "1"' &> /dev/null; then \
+		. venv/bin/activate && python -c "from distutils.version import LooseVersion as V;assert V('$(GCC_VERSION)') >= V('4.9'), 'Requires gcc>=4.9'"; \
+	fi
+	. venv/bin/activate; cd warp-transducer && mkdir build && cd build && cmake .. && make; true
+	. venv/bin/activate; export WARP_RNNT_PATH=$(PWD)/warp-transducer/build/ && cd warp-transducer/pytorch_binding && python setup.py install
+	touch warp-transducer.done
+
 chainer_ctc.done: espnet.done
 	rm -rf chainer_ctc
 	git clone https://github.com/jheymann85/chainer_ctc.git
@@ -177,13 +188,13 @@ PESQ.zip:
 
 clean: clean_extra
-	rm -rf kaldi venv warp-ctc chainer_ctc
+	rm -rf kaldi venv warp-ctc warp-transducer chainer_ctc
 	rm -f miniconda.sh
 	rm -rf *.done
 	find . -iname "*.pyc" -delete
 
 clean_python:
-	rm -rf venv warp-ctc chainer_ctc
+	rm -rf venv warp-ctc warp-transducer chainer_ctc
 	rm -f miniconda.sh
 	rm -f warp-ctc.done chainer_ctc.done espnet.done pytorch_complex pytorch_complex.done
 	find . -iname "*.pyc" -delete
diff --git a/tools/check_install.py b/tools/check_install.py
index ac1ececcf32..ff5376c0dd8 100644
--- a/tools/check_install.py
+++ b/tools/check_install.py
@@ -23,7 +23,8 @@ def main(args):
         ('torch', ("0.4.1", "1.0.0", "1.0.1.post2")),
         ('chainer', ("6.0.0")),
         ('chainer_ctc', None),
-        ('warpctc_pytorch', ("0.1.1"))
+        ('warpctc_pytorch', ("0.1.1")),
+        ('warprnnt_pytorch', ("0.1.1"))
     ]
 
     if not args.no_cupy:
@@ -65,9 +66,15 @@ def main(args):
     is_correct_version_list = []
     for idx, (name, version) in enumerate(library_list):
         if version is not None:
+            # Note: importlib cannot report a version for warprnnt_pytorch,
+            # so fall back to pkg_resources for that package.
-            lib = importlib.import_module(name)
-            if hasattr(lib, "__version__"):
-                is_correct = lib.__version__ in version
+            if name == "warprnnt_pytorch":
+                import pkg_resources
+                vers = pkg_resources.get_distribution(name).version
+            else:
+                vers = importlib.import_module(name).__version__
+            if vers is not None:
+                is_correct = vers in version
             if is_correct:
                 logging.info("--> %s version is matched." % name)
                 is_correct_version_list.append(True)