
Commit

Merge pull request #75 from sooftware/transformer-debug
resolved issue #37
sooftware authored Jan 4, 2021
2 parents 6e70914 + a3edd92 commit dcd2a6a
Showing 4 changed files with 66 additions and 99 deletions.
2 changes: 1 addition & 1 deletion kospeech/models/extractor.py
@@ -130,7 +130,7 @@ class VGGExtractor(CNNExtractor):
"Advances in Joint CTC-Attention based End-to-End Speech Recognition with a Deep CNN Encoder and RNN-LM" paper
- https://arxiv.org/pdf/1706.02737.pdf
"""
def __init__(self, activation: str, mask_conv: bool):
def __init__(self, activation: str = 'hardtanh', mask_conv: bool = False):
super(VGGExtractor, self).__init__(activation)
self.mask_conv = mask_conv
self.conv = nn.Sequential(
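The only change in this file gives VGGExtractor usable defaults, which is what lets model.py below construct it without spelling out the activation. A minimal usage sketch (not part of the diff), assuming kospeech is installed at this revision and using the import path shown later in model.py:

# Sketch only: relies on the defaults introduced in the changed line above.
from kospeech.models.extractor import VGGExtractor

extractor = VGGExtractor()                       # same as VGGExtractor('hardtanh', mask_conv=False)
masked_extractor = VGGExtractor(mask_conv=True)  # override a single argument when needed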
18 changes: 1 addition & 17 deletions kospeech/models/transformer/layers.py
@@ -39,18 +39,9 @@ def __init__(
self.self_attention = AddNorm(MultiHeadAttention(d_model, num_heads), d_model)
self.feed_forward = AddNorm(PositionWiseFeedForwardNet(d_model, d_ff, dropout_p, ffnet_style), d_model)

def forward(
self,
inputs: Tensor, # B x T_input x D
non_pad_mask: Optional[Any] = None, # B x T_input
self_attn_mask: Optional[Any] = None # B x T_input x T_output
) -> Tuple[Tensor, Tensor]:
def forward(self, inputs: Tensor, self_attn_mask: Optional[Any] = None) -> Tuple[Tensor, Tensor]:
output, attn = self.self_attention(inputs, inputs, inputs, self_attn_mask)
output *= non_pad_mask

output = self.feed_forward(output)
output *= non_pad_mask

return output, attn


@@ -84,17 +75,10 @@ def forward(
self,
inputs: Tensor, # B x T_input
memory: Tensor, # B x T_input x D_model
non_pad_mask: Optional[Any] = None, # B x T_input
self_attn_mask: Optional[Any] = None, # B x T_input x T_input
memory_mask: Optional[Any] = None # B x T_input x T_output
) -> Tuple[Tensor, Tensor, Tensor]:
output, self_attn = self.self_attention(inputs, inputs, inputs, self_attn_mask)
output *= non_pad_mask

output, memory_attn = self.memory_attention(output, memory, memory, memory_mask)
output *= non_pad_mask

output = self.feed_forward(output)
output *= non_pad_mask

return output, self_attn, memory_attn
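For context, the simplification above drops the per-position non_pad_mask multiplications and leaves padding handling to the attention mask alone. The following standalone sketch (not part of the diff) mirrors that simplified layer structure using torch.nn.MultiheadAttention in place of kospeech's MultiHeadAttention/AddNorm; it assumes a PyTorch version with batch_first, and it uses a B x T key_padding_mask rather than the B x T x T mask kospeech builds:

import torch
import torch.nn as nn
from torch import Tensor
from typing import Optional, Tuple

class ToyEncoderLayer(nn.Module):
    """Illustrative post-norm encoder layer following the simplified forward above."""
    def __init__(self, d_model: int = 512, num_heads: int = 8, d_ff: int = 2048) -> None:
        super().__init__()
        self.self_attention = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
        self.attn_norm = nn.LayerNorm(d_model)
        self.feed_forward = nn.Sequential(nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model))
        self.ff_norm = nn.LayerNorm(d_model)

    def forward(self, inputs: Tensor, self_attn_mask: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
        # Padding is excluded inside attention via the mask; there is no output *= non_pad_mask step.
        attn_output, attn = self.self_attention(inputs, inputs, inputs, key_padding_mask=self_attn_mask)
        output = self.attn_norm(inputs + attn_output)
        output = self.ff_norm(output + self.feed_forward(output))
        return output, attn

layer = ToyEncoderLayer()
pad_mask = torch.zeros(2, 7, dtype=torch.bool)       # True would mark padded key positions
output, attn = layer(torch.randn(2, 7, 512), pad_mask)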
26 changes: 9 additions & 17 deletions kospeech/models/transformer/mask.py
@@ -9,34 +9,26 @@
from typing import Any, Optional


def get_pad_mask(inputs: Tensor, input_lengths: Optional[Any] = None, pad_id: int = None) -> Tensor:
"""
Padding position is set to True, either use input_lengths or pad_id
Examples::
>>> get_pad_mask(inputs, input_lengths)
tensor([[[False], [False], [False], [False], [False], [ True], [ True], [ True], [ True]],
[[False], [False], [False], [False], [False], [False], [ True], [ True], [ True]],
[[False], [False], [False], [False], [False], [False], [False], [False], [ True]]])
"""
def get_non_pad_mask(inputs: Tensor, input_lengths: Optional[Any] = None, pad_id: int = None) -> Tensor:
""" Padding position is set to 0, either use input_lengths or pad_id """
assert (input_lengths is None and pad_id is not None) or (input_lengths is not None and pad_id is None)

if input_lengths is not None:
batch_size = inputs.size(0)

if len(inputs.size()) == 2:
pad_mask = inputs.new_zeros(inputs.size()) # B x T
non_pad_mask = inputs.new_ones(inputs.size()) # B x T
else:
pad_mask = inputs.new_zeros(inputs.size()[:-1]) # B x T
non_pad_mask = inputs.new_ones(inputs.size()[:-1]) # B x T

for i in range(batch_size):
pad_mask[i, input_lengths[i]:] = 1
non_pad_mask[i, input_lengths[i]:] = 0

if pad_id is not None:
assert inputs.dim() == 2
pad_mask = inputs.eq(pad_id)
non_pad_mask = inputs.ne(pad_id).float()

return pad_mask.unsqueeze(-1).bool()
return non_pad_mask.unsqueeze(-1)


def get_decoder_self_attn_mask(seq_k: Tensor, seq_q: Tensor, pad_id):
@@ -139,8 +131,8 @@ def get_attn_pad_mask(inputs, input_lengths, expand_length):
"""
# N x Ti x 1
non_pad_mask = get_pad_mask(inputs, input_lengths=input_lengths).eq(False)
non_pad_mask = get_non_pad_mask(inputs, input_lengths=input_lengths)
# N x Ti, lt(1) acts as a logical NOT
pad_mask = non_pad_mask.squeeze(-1).eq(False)
pad_mask = non_pad_mask.squeeze(-1).lt(1)
attn_mask = pad_mask.unsqueeze(1).expand(-1, expand_length, -1)
return attn_mask
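To make the renaming above concrete: get_non_pad_mask now returns a float mask with 1 at real frames and 0 at padding, and get_attn_pad_mask turns it into the B x T x T attention mask via lt(1). A small self-contained reproduction of that logic with plain tensor ops (the toy batch below is made up; this is a sketch, not the project code):

import torch

inputs = torch.randn(3, 5, 4)             # B x T x D toy feature batch
input_lengths = [5, 3, 2]                 # valid frames per utterance

# get_non_pad_mask-style: 1.0 for real frames, 0.0 for padding, shape B x T x 1
non_pad_mask = inputs.new_ones(inputs.size()[:-1])
for i, length in enumerate(input_lengths):
    non_pad_mask[i, length:] = 0
non_pad_mask = non_pad_mask.unsqueeze(-1)

# get_attn_pad_mask-style: True marks key positions attention should ignore
pad_mask = non_pad_mask.squeeze(-1).lt(1)                           # lt(1) flips 1/0 into False/True
attn_mask = pad_mask.unsqueeze(1).expand(-1, inputs.size(1), -1)    # B x T x T

print(attn_mask[1].int())   # last two key positions are padding for the second utterance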
119 changes: 55 additions & 64 deletions kospeech/models/transformer/model.py
@@ -13,24 +13,33 @@
import math
import torch
import torch.nn as nn

from torch import Tensor
from typing import Optional, Any
from typing import (
Optional,
Any,
Tuple,
Union,
)
from kospeech.models.extractor import (
VGGExtractor,
DeepSpeech2Extractor,
)
from kospeech.models.modules import (
Linear,
LayerNorm
LayerNorm,
)
from kospeech.models.transformer.mask import (
get_pad_mask,
get_attn_pad_mask,
get_decoder_self_attn_mask
get_decoder_self_attn_mask,
)
from kospeech.models.transformer.embeddings import (
Embedding,
PositionalEncoding
PositionalEncoding,
)
from kospeech.models.transformer.layers import (
SpeechTransformerEncoderLayer,
SpeechTransformerDecoderLayer
SpeechTransformerDecoderLayer,
)


@@ -53,9 +62,13 @@ class SpeechTransformer(nn.Module):
ffnet_style (str): if ffnet_style is 'ff', the position-wise feed-forward network is a feed-forward layer;
otherwise, it is a convolution layer. (default: ff)
Inputs: inputs, targets
- **inputs** (batch, input_length): tensor containing input sequences
- **targets** (batch, target_length): tensor contatining target sequences
Inputs: inputs, input_lengths, targets, return_attns
- **inputs** (torch.Tensor): tensor of sequences, whose length is the batch size and within which
each sequence is a list of token IDs. This information is forwarded to the encoder.
- **input_lengths** (torch.Tensor): tensor containing the lengths of the input sequences.
- **targets** (torch.Tensor): tensor of sequences, whose length is the batch size and within which
each sequence is a list of token IDs. This information is forwarded to the decoder.
- **return_attns** (bool): flag indicating whether to return attention lists
Returns: output
- **output**: tensor containing the outputs
Expand All @@ -80,37 +93,17 @@ def __init__(

assert d_model % num_heads == 0, "d_model % num_heads should be zero."

if extractor.lower() == 'vgg':
self.extractor = extractor

if self.extractor == 'vgg':
input_dim = (input_dim - 1) << 5 if input_dim % 2 else input_dim << 5
self.conv = nn.Sequential(
nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1, bias=False),
nn.BatchNorm2d(num_features=64),
nn.Hardtanh(0, 20, inplace=True),
nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False),
nn.BatchNorm2d(num_features=64),
nn.Hardtanh(0, 20, inplace=True),
nn.MaxPool2d(2, stride=2),
nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=False),
nn.BatchNorm2d(num_features=128),
nn.Hardtanh(0, 20, inplace=True),
nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=False),
nn.BatchNorm2d(num_features=128),
nn.Hardtanh(0, 20, inplace=True),
nn.MaxPool2d(2, stride=2)
)

elif extractor.lower() == 'ds2':
self.conv = VGGExtractor(mask_conv=False)

elif self.extractor == 'ds2':
input_dim = int(math.floor(input_dim + 2 * 20 - 41) / 2 + 1)
input_dim = int(math.floor(input_dim + 2 * 10 - 21) / 2 + 1)
input_dim <<= 5
self.conv = nn.Sequential(
nn.Conv2d(1, 32, kernel_size=(41, 11), stride=(2, 2), padding=(20, 5), bias=False),
nn.BatchNorm2d(32),
nn.Hardtanh(0, 20, inplace=True),
nn.Conv2d(32, 32, kernel_size=(21, 11), stride=(2, 1), padding=(10, 5), bias=False),
nn.BatchNorm2d(32),
nn.Hardtanh(0, 20, inplace=True),
)
input_dim <<= 6
self.conv = DeepSpeech2Extractor(mask_conv=False)

else:
raise ValueError("Unsupported Extractor : {0}".format(extractor))
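The input_dim bookkeeping above tracks how the front-end reshapes features before the encoder projection: the VGG stack (see the removed nn.Sequential) halves the frequency axis with two 2x2 max-pools and leaves 128 channels, so the flattened per-frame size is roughly (input_dim / 4) * 128 = input_dim << 5. A quick sanity check of that arithmetic for the common 80-mel case and its odd neighbour (illustrative values only, not part of the diff):

def vgg_flattened_dim(n_mels: int) -> int:
    """Frequency bins after two 2x2 max-pools, times the 128 output channels."""
    return (n_mels // 2 // 2) * 128

for n_mels in (80, 81):
    shifted = (n_mels - 1) << 5 if n_mels % 2 else n_mels << 5   # expression from the diff
    assert shifted == vgg_flattened_dim(n_mels)
    print(n_mels, shifted)    # both give 2560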
@@ -143,17 +136,19 @@ def __init__(

def forward(
self,
inputs: Tensor, # B x T_input x D_Feature
input_lengths: Tensor, # B
targets: Optional[Tensor] = None, # B x T_output => <sos> a b c d e . . . <eos> <pad> <pad> <pad>
return_attns: bool = False # bool
):
conv_feat = self.conv(inputs.unsqueeze(1))
inputs: Tensor, # tensor of input sequences
input_lengths: Tensor, # tensor of input sequence lengths
targets: Optional[Tensor] = None, # tensor of target sequences
return_attns: bool = False      # flag indicating whether to return attention lists
) -> Union[Tensor, tuple]:
conv_feat = self.conv(inputs.unsqueeze(1), input_lengths)
conv_feat = conv_feat.transpose(1, 2)

batch_size, seq_length, num_channels, hidden_dim = conv_feat.size()
conv_feat = conv_feat.contiguous().view(batch_size, seq_length, num_channels * hidden_dim)
input_lengths = (input_lengths >> 2).int()

if self.extractor == 'vgg':
input_lengths = (input_lengths >> 2).int()

memory, encoder_self_attns = self.encoder(conv_feat, input_lengths)
output, decoder_self_attns, memory_attns = self.decoder(targets, input_lengths, memory)
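The forward pass now also threads input_lengths through the convolutional front-end and, only for the VGG extractor, shrinks the lengths by a factor of 4 to match the two time-axis poolings; that is what the >> 2 computes. A tiny illustration (lengths are made up; sketch only):

import torch

input_lengths = torch.tensor([128, 97, 61])
vgg_lengths = (input_lengths >> 2).int()   # two stride-2 poolings on the time axis -> // 4
print(vgg_lengths)                         # tensor([32, 24, 15], dtype=torch.int32)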
@@ -187,6 +182,10 @@ class SpeechTransformerEncoder(nn.Module):
ffnet_style: style of feed forward network [ff, conv] (default: ff)
dropout_p: probability of dropout (default: 0.3)
pad_id: identification of pad token (default: 0)
Inputs:
- **inputs**: list of sequences, whose length is the batch size and within which each sequence is a list of tokens
- **input_lengths**: list of sequence lengths
"""

def __init__(
Expand All @@ -213,21 +212,14 @@ def __init__(
[SpeechTransformerEncoderLayer(d_model, num_heads, d_ff, dropout_p, ffnet_style) for _ in range(num_layers)]
)

def forward(self, inputs: Tensor, input_lengths: Tensor = None):
"""
Args:
inputs: BxT_inputxD
input_lengths: Bx1
"""
def forward(self, inputs: Tensor, input_lengths: Tensor = None) -> Tuple[Tensor, list]:
self_attns = list()

non_pad_mask = get_pad_mask(inputs, input_lengths=input_lengths).eq(False)
self_attn_mask = get_attn_pad_mask(inputs, input_lengths, inputs.size(1))

output = self.input_dropout(self.input_norm(self.input_proj(inputs)) + self.positional_encoding(inputs.size(1)))

for layer in self.layers:
output, attn = layer(output, non_pad_mask, self_attn_mask)
output, attn = layer(output, self_attn_mask)
self_attns.append(attn)

return output, self_attns
@@ -253,15 +245,15 @@ class SpeechTransformerDecoder(nn.Module):

def __init__(
self,
num_classes: int, # number of classes
d_model: int = 512, # dimension of model
d_ff: int = 512, # dimension of feed forward network
num_layers: int = 6, # number of decoder layers
num_heads: int = 8, # number of attention heads
ffnet_style: str = 'ff', # style of feed forward network
dropout_p: float = 0.3, # probability of dropout
pad_id: int = 0, # identification of pad token
eos_id: int = 2 # identification of end of sentence token
num_classes: int, # number of classes
d_model: int = 512, # dimension of model
d_ff: int = 512, # dimension of feed forward network
num_layers: int = 6, # number of decoder layers
num_heads: int = 8, # number of attention heads
ffnet_style: str = 'ff', # style of feed forward network
dropout_p: float = 0.3, # probability of dropout
pad_id: int = 0, # identification of pad token
eos_id: int = 2 # identification of end of sentence token
) -> None:
super(SpeechTransformerDecoder, self).__init__()
self.d_model = d_model
Expand All @@ -280,14 +272,13 @@ def forward(self, inputs: Tensor, input_lengths: Optional[Any] = None, memory: T
self_attns, memory_attns = list(), list()
batch_size, output_length = inputs.size(0), inputs.size(1)

non_pad_mask = get_pad_mask(inputs, pad_id=self.pad_id).eq(False)
self_attn_mask = get_decoder_self_attn_mask(inputs, inputs, self.pad_id)
memory_mask = get_attn_pad_mask(memory, input_lengths, output_length)

output = self.input_dropout(self.embedding(inputs) + self.positional_encoding(inputs.size(1)))

for layer in self.layers:
output, self_attn, memory_attn = layer(output, memory, non_pad_mask, self_attn_mask, memory_mask)
output, self_attn, memory_attn = layer(output, memory, self_attn_mask, memory_mask)
self_attns.append(self_attn)
memory_attns.append(memory_attn)

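The decoder keeps building its self-attention mask from get_decoder_self_attn_mask, which, judging by its name and how it is used here, combines a padding mask with a subsequent-position (causal) mask so each target position attends only to itself and earlier non-pad tokens. A rough standalone sketch of that combination (the helper below is illustrative, not kospeech's implementation):

import torch

def toy_decoder_self_attn_mask(targets: torch.Tensor, pad_id: int = 0) -> torch.Tensor:
    """Illustrative only: True marks positions the decoder self-attention must ignore."""
    batch_size, seq_length = targets.size()
    # Padding mask: True wherever the key token is the pad symbol (B x T x T)
    pad_mask = targets.eq(pad_id).unsqueeze(1).expand(-1, seq_length, -1)
    # Subsequent mask: True strictly above the diagonal, i.e. future positions (B x T x T)
    subsequent_mask = torch.triu(
        torch.ones(seq_length, seq_length, dtype=torch.bool, device=targets.device), diagonal=1
    ).unsqueeze(0).expand(batch_size, -1, -1)
    return pad_mask | subsequent_mask

targets = torch.tensor([[1, 5, 7, 2, 0, 0]])   # e.g. <sos> a b <eos> <pad> <pad>
print(toy_decoder_self_attn_mask(targets)[0].int())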

