RNN-T and RNN-T-Att for ESPNET v0.5.0
b-flo committed Aug 8, 2019
1 parent ea313cd commit 7dc88cd
Showing 14 changed files with 2,576 additions and 12 deletions.
1 change: 1 addition & 0 deletions .dockerignore
@@ -10,6 +10,7 @@ tools/miniconda.sh
tools/nkf/
tools/venv/
tools/warp-ctc/
tools/warp-transducer/
tools/chainer_ctc/
tools/subword-nmt/

1 change: 1 addition & 0 deletions .gitignore
@@ -46,5 +46,6 @@ tools/venv/
tools/sentencepiece/
tools/swig/
tools/warp-ctc/
tools/warp-transducer/
tools/*.done
tools/PESQ*
4 changes: 4 additions & 0 deletions egs/voxforge/asr1/conf/tuning/decode_rnnt.yaml
@@ -0,0 +1,4 @@
# decoding parameters
beam-size: 20
search-type: beam
score-norm: True
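For reference, score-norm controls whether final hypothesis scores are divided by the hypothesis length before ranking. The snippet below is only a minimal illustration of that idea, not ESPnet's decoding code; the Hypothesis class and rank_hypotheses helper are hypothetical.

# Minimal illustration of length-normalized hypothesis ranking (hypothetical helper,
# not the ESPnet beam-search implementation).
from typing import List, NamedTuple

class Hypothesis(NamedTuple):
    yseq: List[int]   # decoded token ids
    score: float      # accumulated log-probability

def rank_hypotheses(hyps, score_norm=True):
    """Sort hypotheses by score, optionally normalized by output length."""
    if score_norm:
        return sorted(hyps, key=lambda h: h.score / max(len(h.yseq), 1), reverse=True)
    return sorted(hyps, key=lambda h: h.score, reverse=True)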
46 changes: 46 additions & 0 deletions egs/voxforge/asr1/conf/tuning/train_rnnt.yaml
@@ -0,0 +1,46 @@
# minibatch related
batch-size: 20
maxlen-in: 800
maxlen-out: 150

# optimization related
sortagrad: 0
opt: adadelta
epochs: 20
patience: 3

# network architecture
## encoder related
etype: vggblstmp
elayers: 4
eunits: 320
eprojs: 320
## decoder related
dtype: lstm
dlayers: 1
dunits: 320
dec-embed-dim: 320
## joint network related
joint-dim: 320

# rnn-t related (0:rnnt, 1:rnnt-att)
rnnt-mode: 0

# finetuning related
## Note: the current implementation only allows pre-initialization with models
## matching the configuration specified above. The architecture you specify
## should match the architecture of the modules you want to finetune from.
## For example, if you want to pre-initialize the decoder embedding layer
## from a specified model, the "dec-embed-dim" param should be set accordingly.

# the following model corresponds to conf/tuning/train_mtlalpha1.0.yaml
#enc-init: "exp/tr_it_pytorch_train_mtlalpha1.0/results/model.loss.best"
#enc-init-mods: "enc.enc."

# the following model is a CE-trained RNNLM similar to the one in:
# egs/librispeech/asr1/conf/tuning/lm.yaml
#dec-init: "exp/train_rnnlm_pytorch/rnnlm.model.best"
#dec-init-mods: "predictor.rnn.,predictor.embed."

# freeze modules
#freeze-modules: "predictor.rnn."
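The eprojs, dunits and joint-dim values above all feed the RNN-T joint network, which combines an encoder frame and a prediction-network state before the output projection. Below is a rough sketch of a standard RNN-T joint network using those dimensions; the class and the output size (odim) are illustrative, not the modules added in this commit.

# Rough sketch of a standard RNN-T joint network (illustrative, not this commit's code).
import torch

class JointNetwork(torch.nn.Module):
    def __init__(self, eprojs=320, dunits=320, joint_dim=320, odim=52):
        super().__init__()
        self.lin_enc = torch.nn.Linear(eprojs, joint_dim)
        self.lin_dec = torch.nn.Linear(dunits, joint_dim, bias=False)
        self.lin_out = torch.nn.Linear(joint_dim, odim)

    def forward(self, h_enc, h_dec):
        # h_enc: encoder output at frame t, h_dec: prediction-network output at label step u
        z = torch.tanh(self.lin_enc(h_enc) + self.lin_dec(h_dec))
        return self.lin_out(z)  # logits over output tokens, including blank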
43 changes: 43 additions & 0 deletions egs/voxforge/asr1/conf/tuning/train_rnnt_att.yaml
@@ -0,0 +1,43 @@
# minibatch related
batch-size: 20
maxlen-in: 800
maxlen-out: 150

# optimization related
sortagrad: 0
opt: adadelta
epochs: 20
patience: 3

# network architecture
## encoder related
etype: vggblstmp
elayers: 4
eunits: 320
eprojs: 320
## decoder related
dlayers: 1
dunits: 300
## attention related
atype: location
adim: 320
aconv-chans: 10
aconv-filts: 100
## joint network related
joint-dim: 320

# rnn-t related (0:rnnt, 1:rnnt-att)
rnnt-mode: 1

# finetuning related
## Note: the current implementation only allows pre-initialization with models
## matching the configuration specified above. The architecture you specify
## should match the architecture of the modules you want to finetune from.
## For example, if you want to pre-initialize the decoder embedding layer
## from a specified model, the "dec-embed-dim" param should be set accordingly.

# the following model corresponds to conf/tuning/train_mtlalpha1.0.yaml
#enc-init: "exp/tr_it_pytorch_train_mtlalpha1.0/results/model.loss.best"

# the following model corresponds to conf/tuning/train_mtlalpha0.0.yaml
#dec-init: "exp/tr_it_pytorch_train_mtlalpha0.0/results/model.loss.best"
7 changes: 2 additions & 5 deletions egs/voxforge/asr1/run.sh
@@ -3,8 +3,8 @@
# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

. ./path.sh
. ./cmd.sh
. ./path.sh || exit 1;
. ./cmd.sh || exit 1;

# general configuration
backend=pytorch
@@ -36,9 +36,6 @@ tag="" # tag for managing experiments.

. utils/parse_options.sh || exit 1;

. ./path.sh
. ./cmd.sh

# Set bash to 'debug' mode, it will exit on :
# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands',
set -e
152 changes: 152 additions & 0 deletions espnet/asr/asr_rnnt_utils.py
@@ -0,0 +1,152 @@
import os
import logging
import torch

from collections import OrderedDict

# Note: the following methods are just toy examples for finetuning RNN-T and RNN-T att.
# It is quite inelegant as is, but it handles most of the cases.

def load_pretrained_modules(model, rnnt_mode, enc_pretrained, dec_pretrained,
enc_mods, dec_mods):
"""Method to update model modules weights from up to two ESPNET pre-trained models.
Specified models can be either trained with CTC, attention or joint CTC-attention,
or a language model trained with CE to initialize the decoder part in vanilla RNN-T.
Args:
model (torch.nn.Module): initial torch model
rnnt_mode (int): RNN-transducer mode
enc_pretrained (str): ESPNET pre-trained model path file to initialize encoder part
dec_pretrained (str): ESPNET pre-trained model path file to initialize decoder part
"""

def filter_modules(model, modules):
new_mods = []
incorrect_mods = []

mods_model = list(model.keys())
for mod in modules:
if any(key.startswith(mod) for key in mods_model):
new_mods += [mod]
else:
incorrect_mods += [mod]

if incorrect_mods:
logging.info("Some specified module(s) for finetuning don\'t "
"match (or partially match) available modules."
" Disabling the following modules: %s", incorrect_mods)
logging.info('For information, the existing modules in model are:')
logging.info('%s', mods_model)

return new_mods

def validate_modules(model, prt, modules):
modules_model = []
modules_prt = []

for key_p, value_p in prt.items():
if any(key_p.startswith(m) for m in modules):
modules_prt += [(key_p, value_p.shape)]

for key_m, value_m in model.items():
if any(key_m.startswith(m) for m in modules):
modules_model += [(key_m, value_m.shape)]

        len_match = (len(modules_model) == len(modules_prt))
        module_match = (sorted(modules_model) == sorted(modules_prt))

return len_match and module_match

def get_am_state_dict(model, modules):
new_state_dict = OrderedDict()

for key, value in model.items():
if any(key.startswith(m) for m in modules):
            if 'output' not in key:
new_state_dict[key] = value

return new_state_dict

def get_lm_state_dict(model, modules):
new_state_dict = OrderedDict()
new_modules = []

for key, value in list(model.items()):
if key == "predictor.embed.weight" \
and "predictor.embed." in modules:
new_key = "dec.embed.weight"
new_state_dict[new_key] = value
new_modules += [new_key]
elif "predictor.rnn." in key \
and "predictor.rnn." in modules:
new_key = "dec.decoder." + key.split("predictor.rnn.",1)[1]
new_state_dict[new_key] = value
new_modules += [new_key]

return new_state_dict, new_modules

model_state_dict = model.state_dict()

    # Treat dec_pretrained as a language model only in vanilla RNN-T mode (rnnt_mode == 0).
    lm_init = (rnnt_mode == 0 and dec_pretrained is not None)

for prt_model_path, prt_mods, prt_type in [(enc_pretrained, enc_mods, False),
(dec_pretrained, dec_mods, lm_init)]:
if prt_model_path is not None:
if os.path.isfile(prt_model_path):
prt_model = torch.load(prt_model_path,
map_location=lambda storage, loc: storage)

prt_mods = filter_modules(prt_model, prt_mods)
if prt_type:
prt_state_dict, prt_mods = get_lm_state_dict(prt_model, prt_mods)
else:
prt_state_dict = get_am_state_dict(prt_model, prt_mods)

if prt_state_dict:
if validate_modules(model_state_dict, prt_state_dict, prt_mods):
model_state_dict.update(prt_state_dict)
else:
logging.info("The model you specified for pre-initialization "
"doesn\'t match your run config. It will be ignored")
logging.info('Model path: %s' % prt_model_path)
else:
logging.info('The model you specified for pre-initialization was not found.')
                logging.info('Model path: %s', prt_model_path)

model.load_state_dict(model_state_dict)

del model_state_dict

return model

def freeze_modules(model, modules):
"""Method to freeze specified list of modules.
Args:
model (torch.nn.Module): torch model
modules (str): list of module names to freeze during training
Returns:
(boolean): if True, filter the specified modules in the optimizer
"""

mods_model = [name for name, _ in model.named_parameters()]

if not any(i in j for j in mods_model for i in modules):
logging.info("Some module(s) you specified to freeze don\'t "
"match (or partially match) available modules.")
logging.info("Disabling the option.")
logging.info('For information, the existing modules in model are:')
logging.info('%s', mods_model)

return False

for name, param in model.named_parameters():
if any(name.startswith(m) for m in modules):
param.requires_grad = False

return True
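A rough usage sketch of the two helpers above follows. The argument names (args.enc_init, args.freeze_mods, ...) and the setup_finetuning function are illustrative placeholders, not the training script's actual wiring in this commit.

# Illustrative wiring only; the argument names and setup function are placeholders.
from espnet.asr.asr_rnnt_utils import load_pretrained_modules, freeze_modules

def setup_finetuning(model, args):
    # Pre-initialize encoder and/or decoder modules from pre-trained models.
    if args.enc_init is not None or args.dec_init is not None:
        enc_mods = args.enc_init_mods.split(',') if args.enc_init_mods else []
        dec_mods = args.dec_init_mods.split(',') if args.dec_init_mods else []
        model = load_pretrained_modules(model, args.rnnt_mode,
                                        args.enc_init, args.dec_init,
                                        enc_mods, dec_mods)

    # Optionally freeze modules; if freezing succeeded, exclude the frozen
    # parameters from the optimizer.
    trainable_params = model.parameters()
    if args.freeze_mods and freeze_modules(model, args.freeze_mods.split(',')):
        trainable_params = [p for p in model.parameters() if p.requires_grad]

    return model, trainable_params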