From a3edd928190022c21ffb93a6e52962b66293c40b Mon Sep 17 00:00:00 2001
From: sooftware
Date: Tue, 5 Jan 2021 02:17:13 +0900
Subject: [PATCH] resolved issue #37

---
 kospeech/models/extractor.py           |   2 +-
 kospeech/models/transformer/layers.py  |  18 +---
 kospeech/models/transformer/mask.py    |  26 ++----
 kospeech/models/transformer/model.py   | 119 ++++++++++++--------------
 4 files changed, 66 insertions(+), 99 deletions(-)

diff --git a/kospeech/models/extractor.py b/kospeech/models/extractor.py
index de7a8b55..696bf869 100644
--- a/kospeech/models/extractor.py
+++ b/kospeech/models/extractor.py
@@ -130,7 +130,7 @@ class VGGExtractor(CNNExtractor):
     "Advances in Joint CTC-Attention based End-to-End Speech Recognition
     with a Deep CNN Encoder and RNN-LM" paper
     - https://arxiv.org/pdf/1706.02737.pdf
     """
-    def __init__(self, activation: str, mask_conv: bool):
+    def __init__(self, activation: str = 'hardtanh', mask_conv: bool = False):
         super(VGGExtractor, self).__init__(activation)
         self.mask_conv = mask_conv
         self.conv = nn.Sequential(
diff --git a/kospeech/models/transformer/layers.py b/kospeech/models/transformer/layers.py
index 4566c741..4ce5fca8 100644
--- a/kospeech/models/transformer/layers.py
+++ b/kospeech/models/transformer/layers.py
@@ -39,18 +39,9 @@ def __init__(
         self.self_attention = AddNorm(MultiHeadAttention(d_model, num_heads), d_model)
         self.feed_forward = AddNorm(PositionWiseFeedForwardNet(d_model, d_ff, dropout_p, ffnet_style), d_model)
 
-    def forward(
-            self,
-            inputs: Tensor,                        # B x T_input x D
-            non_pad_mask: Optional[Any] = None,    # B x T_input
-            self_attn_mask: Optional[Any] = None   # B x T_input x T_output
-    ) -> Tuple[Tensor, Tensor]:
+    def forward(self, inputs: Tensor, self_attn_mask: Optional[Any] = None) -> Tuple[Tensor, Tensor]:
         output, attn = self.self_attention(inputs, inputs, inputs, self_attn_mask)
-        output *= non_pad_mask
-
         output = self.feed_forward(output)
-        output *= non_pad_mask
-
         return output, attn
@@ -84,17 +75,10 @@ def forward(
             self,
             inputs: Tensor,                        # B x T_input
             memory: Tensor,                        # B x T_input x D_model
-            non_pad_mask: Optional[Any] = None,    # B x T_input
             self_attn_mask: Optional[Any] = None,  # B x T_input x T_input
             memory_mask: Optional[Any] = None      # B x T_input x T_output
     ) -> Tuple[Tensor, Tensor, Tensor]:
         output, self_attn = self.self_attention(inputs, inputs, inputs, self_attn_mask)
-        output *= non_pad_mask
-
         output, memory_attn = self.memory_attention(output, memory, memory, memory_mask)
-        output *= non_pad_mask
-
         output = self.feed_forward(output)
-        output *= non_pad_mask
-
         return output, self_attn, memory_attn
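The 'output *= non_pad_mask' steps can be dropped because the self-attention pad mask already keeps every real position from attending to padded key positions, so whatever values remain at the padded time steps are never mixed back into real positions and are ignored by the pad-aware loss. A minimal sketch of that masking in plain PyTorch (illustrative shapes and values; this is not kospeech's actual MultiHeadAttention):

import torch
import torch.nn.functional as F

def scaled_dot_product_attention(query, key, value, mask=None):
    # query/key/value: B x T x D; mask: B x T x T, True where the key is padding
    score = torch.bmm(query, key.transpose(1, 2)) / key.size(-1) ** 0.5
    if mask is not None:
        score = score.masked_fill(mask, -1e9)   # padded keys get ~zero attention weight
    attn = F.softmax(score, dim=-1)
    return torch.bmm(attn, value), attn

batch_size, seq_length, d_model = 3, 9, 16
inputs = torch.randn(batch_size, seq_length, d_model)
input_lengths = torch.tensor([5, 6, 9])

# True wherever the key position lies beyond that utterance's length
pad_mask = torch.arange(seq_length)[None, :] >= input_lengths[:, None]   # B x T
attn_mask = pad_mask.unsqueeze(1).expand(-1, seq_length, -1)             # B x T x T

output, attn = scaled_dot_product_attention(inputs, inputs, inputs, attn_mask)
print(attn[0, :, 5:].abs().sum())   # tensor(0.): no real position attends to padding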
diff --git a/kospeech/models/transformer/mask.py b/kospeech/models/transformer/mask.py
index c9b8cf53..134cecd0 100644
--- a/kospeech/models/transformer/mask.py
+++ b/kospeech/models/transformer/mask.py
@@ -9,34 +9,26 @@ from typing import Any, Optional
 
 
-def get_pad_mask(inputs: Tensor, input_lengths: Optional[Any] = None, pad_id: int = None) -> Tensor:
-    """
-    Padding position is set to True, either use input_lengths or pad_id
-
-    Examples::
-        >>> get_pad_mask(inputs, input_lengths)
-        tensor([[[False], [False], [False], [False], [False], [ True], [ True], [ True], [ True]],
-                [[False], [False], [False], [False], [False], [False], [ True], [ True], [ True]],
-                [[False], [False], [False], [False], [False], [False], [False], [False], [ True]]])
-    """
+def get_non_pad_mask(inputs: Tensor, input_lengths: Optional[Any] = None, pad_id: int = None) -> Tensor:
+    """ Padding positions are set to 0; pass either input_lengths or pad_id """
     assert (input_lengths is None and pad_id is not None) or (input_lengths is not None and pad_id is None)
 
     if input_lengths is not None:
         batch_size = inputs.size(0)
 
         if len(inputs.size()) == 2:
-            pad_mask = inputs.new_zeros(inputs.size())              # B x T
+            non_pad_mask = inputs.new_ones(inputs.size())           # B x T
         else:
-            pad_mask = inputs.new_zeros(inputs.size()[:-1])         # B x T
+            non_pad_mask = inputs.new_ones(inputs.size()[:-1])      # B x T
 
         for i in range(batch_size):
-            pad_mask[i, input_lengths[i]:] = 1
+            non_pad_mask[i, input_lengths[i]:] = 0
 
     if pad_id is not None:
         assert inputs.dim() == 2
-        pad_mask = inputs.eq(pad_id)
+        non_pad_mask = inputs.ne(pad_id).float()
 
-    return pad_mask.unsqueeze(-1).bool()
+    return non_pad_mask.unsqueeze(-1)
 
 
 def get_decoder_self_attn_mask(seq_k: Tensor, seq_q: Tensor, pad_id):
@@ -139,8 +131,8 @@ def get_attn_pad_mask(inputs, input_lengths, expand_length):
     """
     # N x Ti x 1
-    non_pad_mask = get_pad_mask(inputs, input_lengths=input_lengths).eq(False)
+    non_pad_mask = get_non_pad_mask(inputs, input_lengths=input_lengths)
     # N x Ti, lt(1) like not operation
-    pad_mask = non_pad_mask.squeeze(-1).eq(False)
+    pad_mask = non_pad_mask.squeeze(-1).lt(1)
     attn_mask = pad_mask.unsqueeze(1).expand(-1, expand_length, -1)
     return attn_mask
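After this change the helper is sign-consistent with its name: get_non_pad_mask returns a float mask of shape B x T x 1 holding 1.0 at real positions and 0.0 at padding, and get_attn_pad_mask flips it into the boolean mask that attention consumes. A quick usage sketch, assuming the patched kospeech modules are importable (tensor sizes are arbitrary):

import torch
from kospeech.models.transformer.mask import get_non_pad_mask, get_attn_pad_mask

inputs = torch.randn(3, 9, 80)        # B x T x D acoustic features
input_lengths = torch.tensor([9, 6, 5])

non_pad_mask = get_non_pad_mask(inputs, input_lengths=input_lengths)
print(non_pad_mask.shape)             # torch.Size([3, 9, 1]); 1.0 up to each length, 0.0 after

attn_mask = get_attn_pad_mask(inputs, input_lengths, inputs.size(1))
print(attn_mask.shape)                # torch.Size([3, 9, 9]); True wherever the key position is padding
print(attn_mask[1, 0])                # tensor([False, False, False, False, False, False,  True,  True,  True])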
diff --git a/kospeech/models/transformer/model.py b/kospeech/models/transformer/model.py
index 453193f2..cc36acd2 100644
--- a/kospeech/models/transformer/model.py
+++ b/kospeech/models/transformer/model.py
@@ -13,24 +13,33 @@
 import math
 import torch
 import torch.nn as nn
+
 from torch import Tensor
-from typing import Optional, Any
+from typing import (
+    Optional,
+    Any,
+    Tuple,
+    Union,
+)
+from kospeech.models.extractor import (
+    VGGExtractor,
+    DeepSpeech2Extractor,
+)
 from kospeech.models.modules import (
     Linear,
-    LayerNorm
+    LayerNorm,
 )
 from kospeech.models.transformer.mask import (
-    get_pad_mask,
     get_attn_pad_mask,
-    get_decoder_self_attn_mask
+    get_decoder_self_attn_mask,
 )
 from kospeech.models.transformer.embeddings import (
     Embedding,
-    PositionalEncoding
+    PositionalEncoding,
 )
 from kospeech.models.transformer.layers import (
     SpeechTransformerEncoderLayer,
-    SpeechTransformerDecoderLayer
+    SpeechTransformerDecoderLayer,
 )
 
 
@@ -53,9 +62,13 @@ class SpeechTransformer(nn.Module):
         ffnet_style (str): if poswise_ffnet is 'ff', position-wise feed forware network to be a feed forward,
             otherwise, position-wise feed forward network to be a convolution layer. (default: ff)
 
-    Inputs: inputs, targets
-        - **inputs** (batch, input_length): tensor containing input sequences
-        - **targets** (batch, target_length): tensor contatining target sequences
+    Inputs: inputs, input_lengths, targets, return_attns
+        - **inputs** (torch.Tensor): tensor of padded feature vector sequences, whose length is the batch size
+          and within which each sequence is a list of acoustic feature vectors. This information is forwarded to the encoder.
+        - **input_lengths** (torch.Tensor): tensor containing the length of each sequence in inputs.
+        - **targets** (torch.Tensor): tensor of target sequences, whose length is the batch size and within
+          which each sequence is a list of token IDs. This information is forwarded to the decoder.
+        - **return_attns** (bool): flag indicating whether to return attention lists
 
     Returns: output
         - **output**: tensor containing the outputs
@@ -80,37 +93,17 @@ def __init__(
         assert d_model % num_heads == 0, "d_model % num_heads should be zero."
 
-        if extractor.lower() == 'vgg':
+        self.extractor = extractor
+
+        if self.extractor == 'vgg':
             input_dim = (input_dim - 1) << 5 if input_dim % 2 else input_dim << 5
-            self.conv = nn.Sequential(
-                nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1, bias=False),
-                nn.BatchNorm2d(num_features=64),
-                nn.Hardtanh(0, 20, inplace=True),
-                nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False),
-                nn.BatchNorm2d(num_features=64),
-                nn.Hardtanh(0, 20, inplace=True),
-                nn.MaxPool2d(2, stride=2),
-                nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=False),
-                nn.BatchNorm2d(num_features=128),
-                nn.Hardtanh(0, 20, inplace=True),
-                nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=False),
-                nn.BatchNorm2d(num_features=128),
-                nn.Hardtanh(0, 20, inplace=True),
-                nn.MaxPool2d(2, stride=2)
-            )
-
-        elif extractor.lower() == 'ds2':
+            self.conv = VGGExtractor(mask_conv=False)
+
+        elif self.extractor == 'ds2':
             input_dim = int(math.floor(input_dim + 2 * 20 - 41) / 2 + 1)
             input_dim = int(math.floor(input_dim + 2 * 10 - 21) / 2 + 1)
-            input_dim <<= 5
-            self.conv = nn.Sequential(
-                nn.Conv2d(1, 32, kernel_size=(41, 11), stride=(2, 2), padding=(20, 5), bias=False),
-                nn.BatchNorm2d(32),
-                nn.Hardtanh(0, 20, inplace=True),
-                nn.Conv2d(32, 32, kernel_size=(21, 11), stride=(2, 1), padding=(10, 5), bias=False),
-                nn.BatchNorm2d(32),
-                nn.Hardtanh(0, 20, inplace=True),
-            )
+            input_dim <<= 6
+            self.conv = DeepSpeech2Extractor(mask_conv=False)
 
         else:
             raise ValueError("Unsupported Extractor : {0}".format(extractor))
@@ -143,17 +136,19 @@ def __init__(
 
     def forward(
             self,
-            inputs: Tensor,                         # B x T_input x D_Feature
-            input_lengths: Tensor,                  # B
-            targets: Optional[Tensor] = None,       # B x T_output => a b c d e . . .
-            return_attns: bool = False              # bool
-    ):
-        conv_feat = self.conv(inputs.unsqueeze(1))
+            inputs: Tensor,                    # tensor of input sequences
+            input_lengths: Tensor,             # tensor of input sequence lengths
+            targets: Optional[Tensor] = None,  # tensor of target sequences
+            return_attns: bool = False         # flag indicating whether to return attention lists
+    ) -> Union[Tensor, tuple]:
+        conv_feat = self.conv(inputs.unsqueeze(1), input_lengths)
         conv_feat = conv_feat.transpose(1, 2)
 
         batch_size, seq_length, num_channels, hidden_dim = conv_feat.size()
         conv_feat = conv_feat.contiguous().view(batch_size, seq_length, num_channels * hidden_dim)
-        input_lengths = (input_lengths >> 2).int()
+
+        if self.extractor == 'vgg':
+            input_lengths = (input_lengths >> 2).int()
 
         memory, encoder_self_attns = self.encoder(conv_feat, input_lengths)
         output, decoder_self_attns, memory_attns = self.decoder(targets, input_lengths, memory)
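The new 'if self.extractor == 'vgg'' guard exists because only the VGG front end pools the time axis: its two MaxPool2d(2, stride=2) layers each halve it, so sequence lengths shrink by a factor of four, and the frequency axis shrinks the same way before being multiplied by the final 128 channels, which is where the 'input_dim << 5' above comes from. A worked check of that arithmetic (the 80-dim input is an arbitrary example):

import torch

# time axis: two stride-2 poolings => floor(T / 4) frames survive
input_lengths = torch.tensor([100, 80, 57])
print((input_lengths >> 2).int())      # tensor([25, 20, 14], dtype=torch.int32)

# frequency axis: 80 mel bins -> 40 -> 20 after the two poolings,
# then 20 bins * 128 channels = 2560 flattened features = 80 << 5
input_dim = 80
assert (input_dim >> 2) * 128 == input_dim << 5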
@@ -187,6 +182,10 @@ class SpeechTransformerEncoder(nn.Module):
         ffnet_style: style of feed forward network [ff, conv] (default: ff)
         dropout_p: probability of dropout (default: 0.3)
         pad_id: identification of pad token (default: 0)
+
+    Inputs:
+        - **inputs**: list of sequences, whose length is the batch size and within which each sequence is a list of feature frames
+        - **input_lengths**: list of sequence lengths
     """
 
     def __init__(
@@ -213,21 +212,14 @@ def __init__(
             [SpeechTransformerEncoderLayer(d_model, num_heads, d_ff, dropout_p, ffnet_style) for _ in range(num_layers)]
         )
 
-    def forward(self, inputs: Tensor, input_lengths: Tensor = None):
-        """
-        Args:
-            inputs: BxT_inputxD
-            input_lengths: Bx1
-        """
+    def forward(self, inputs: Tensor, input_lengths: Tensor = None) -> Tuple[Tensor, list]:
         self_attns = list()
-
-        non_pad_mask = get_pad_mask(inputs, input_lengths=input_lengths).eq(False)
         self_attn_mask = get_attn_pad_mask(inputs, input_lengths, inputs.size(1))
 
         output = self.input_dropout(self.input_norm(self.input_proj(inputs)) + self.positional_encoding(inputs.size(1)))
 
         for layer in self.layers:
-            output, attn = layer(output, non_pad_mask, self_attn_mask)
+            output, attn = layer(output, self_attn_mask)
             self_attns.append(attn)
 
         return output, self_attns
@@ -253,15 +245,15 @@ class SpeechTransformerDecoder(nn.Module):
 
     def __init__(
             self,
-            num_classes: int,          # number of classes
-            d_model: int = 512,        # dimension of model
-            d_ff: int = 512,           # dimension of feed forward network
-            num_layers: int = 6,       # number of decoder layers
-            num_heads: int = 8,        # number of attention heads
-            ffnet_style: str = 'ff',   # style of feed forward network
-            dropout_p: float = 0.3,    # probability of dropout
-            pad_id: int = 0,           # identification of pad token
-            eos_id: int = 2            # identification of end of sentence token
+            num_classes: int,               # number of classes
+            d_model: int = 512,             # dimension of model
+            d_ff: int = 512,                # dimension of feed forward network
+            num_layers: int = 6,            # number of decoder layers
+            num_heads: int = 8,             # number of attention heads
+            ffnet_style: str = 'ff',        # style of feed forward network
+            dropout_p: float = 0.3,         # probability of dropout
+            pad_id: int = 0,                # identification of pad token
+            eos_id: int = 2                 # identification of end of sentence token
     ) -> None:
         super(SpeechTransformerDecoder, self).__init__()
         self.d_model = d_model
@@ -280,14 +272,13 @@ def forward(self, inputs: Tensor, input_lengths: Optional[Any] = None, memory: T
         self_attns, memory_attns = list(), list()
         batch_size, output_length = inputs.size(0), inputs.size(1)
 
-        non_pad_mask = get_pad_mask(inputs, pad_id=self.pad_id).eq(False)
         self_attn_mask = get_decoder_self_attn_mask(inputs, inputs, self.pad_id)
         memory_mask = get_attn_pad_mask(memory, input_lengths, output_length)
 
         output = self.input_dropout(self.embedding(inputs) + self.positional_encoding(inputs.size(1)))
 
         for layer in self.layers:
-            output, self_attn, memory_attn = layer(output, memory, non_pad_mask, self_attn_mask, memory_mask)
+            output, self_attn, memory_attn = layer(output, memory, self_attn_mask, memory_mask)
             self_attns.append(self_attn)
             memory_attns.append(memory_attn)
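For completeness: the decoder's self_attn_mask from get_decoder_self_attn_mask combines two constraints, namely that a target position may attend neither to padding nor to future positions. The patch only references that helper, so the following is a generic sketch of the combination rather than kospeech's exact implementation (function name and tensor values are illustrative):

import torch

def decoder_self_attn_mask_sketch(targets: torch.Tensor, pad_id: int) -> torch.Tensor:
    """Union of a key-is-padding mask and an upper-triangular future mask (B x T x T)."""
    batch_size, seq_length = targets.size()
    pad_mask = targets.eq(pad_id).unsqueeze(1).expand(-1, seq_length, -1)        # B x T x T
    future_mask = torch.ones(seq_length, seq_length, dtype=torch.bool).triu(1)   # T x T
    return pad_mask | future_mask.unsqueeze(0)

targets = torch.tensor([[1, 5, 7, 2, 0]])   # 0 = pad, 2 = eos
mask = decoder_self_attn_mask_sketch(targets, pad_id=0)
print(mask[0, 1])   # tensor([False, False,  True,  True,  True]): step 1 sees steps 0-1 only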