Add MHSA module, Conformer block and encoder with relative PE (#55)
Co-authored-by: Albert Zeyer <[email protected]>
kuacakuaca and albertz authored Sep 12, 2024
1 parent 1972683 commit 9c0fe3f
Showing 11 changed files with 842 additions and 7 deletions.
1 change: 1 addition & 0 deletions i6_models/assemblies/conformer/__init__.py
@@ -1,2 +1,3 @@
from .conformer_v1 import *
from .conformer_v2 import *
from .conformer_rel_pos_v1 import *
126 changes: 126 additions & 0 deletions i6_models/assemblies/conformer/conformer_rel_pos_v1.py
@@ -0,0 +1,126 @@
from __future__ import annotations

__all__ = [
"ConformerRelPosBlockV1Config",
"ConformerRelPosEncoderV1Config",
"ConformerRelPosBlockV1",
"ConformerRelPosEncoderV1",
]

import torch
from torch import nn
from dataclasses import dataclass, field
from typing import List

from i6_models.config import ModelConfiguration, ModuleFactoryV1
from i6_models.parts.conformer import (
ConformerConvolutionV2,
ConformerConvolutionV2Config,
ConformerMHSARelPosV1,
ConformerMHSARelPosV1Config,
ConformerPositionwiseFeedForwardV2,
ConformerPositionwiseFeedForwardV2Config,
)
from i6_models.assemblies.conformer import ConformerEncoderV2


@dataclass
class ConformerRelPosBlockV1Config(ModelConfiguration):
"""
Attributes:
ff_cfg: Configuration for ConformerPositionwiseFeedForwardV2
mhsa_cfg: Configuration for ConformerMHSARelPosV1
conv_cfg: Configuration for ConformerConvolutionV2
modules: List of module names used to build ConformerRelPosBlockV1:
"ff" for the feed-forward module, "mhsa" for the multi-head self-attention module, "conv" for the convolution module
scales: List of scales to apply to the module outputs before the residual connection
"""

# nested configurations
ff_cfg: ConformerPositionwiseFeedForwardV2Config
mhsa_cfg: ConformerMHSARelPosV1Config
conv_cfg: ConformerConvolutionV2Config
modules: List[str] = field(default_factory=lambda: ["ff", "mhsa", "conv", "ff"])
scales: List[float] = field(default_factory=lambda: [0.5, 1.0, 1.0, 0.5])

def __post_init__(self):
super().__post_init__()
assert len(self.modules) == len(self.scales), "modules and scales must have same length"
for module_name in self.modules:
assert module_name in ["ff", "mhsa", "conv"], "module not supported"
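
A minimal usage sketch for this config (not part of the commit): the three nested sub-configs are assumed to be built elsewhere, and ConformerMHSARelPosV1Config is defined in i6_models/parts/conformer/mhsa_rel_pos.py, which is not shown in this diff.

# Sketch only: ff_cfg, mhsa_cfg and conv_cfg are placeholders for already-built
# ConformerPositionwiseFeedForwardV2Config, ConformerMHSARelPosV1Config and
# ConformerConvolutionV2Config instances.
block_cfg = ConformerRelPosBlockV1Config(
    ff_cfg=ff_cfg,
    mhsa_cfg=mhsa_cfg,
    conv_cfg=conv_cfg,
    modules=["ff", "conv", "mhsa", "ff"],  # custom module order; each entry must be "ff", "mhsa" or "conv"
    scales=[0.5, 1.0, 1.0, 0.5],           # one scale per module, checked in __post_init__
)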


class ConformerRelPosBlockV1(nn.Module):
"""
Conformer block module, modifications compared to ConformerBlockV1:
- uses ConformerMHSARelPosV1 as the MHSA module
- allows constructing the block from a self-defined module list, as in ConformerBlockV2
"""

def __init__(self, cfg: ConformerRelPosBlockV1Config):
"""
:param cfg: conformer block configuration with subunits for the different conformer parts
"""
super().__init__()

modules = []
for module_name in cfg.modules:
if module_name == "ff":
modules.append(ConformerPositionwiseFeedForwardV2(cfg=cfg.ff_cfg))
elif module_name == "mhsa":
modules.append(ConformerMHSARelPosV1(cfg=cfg.mhsa_cfg))
elif module_name == "conv":
modules.append(ConformerConvolutionV2(model_cfg=cfg.conv_cfg))
else:
raise NotImplementedError

self.module_list = nn.ModuleList(modules)
self.scales = cfg.scales
self.final_layer_norm = torch.nn.LayerNorm(cfg.ff_cfg.input_dim)

def forward(self, x: torch.Tensor, /, sequence_mask: torch.Tensor) -> torch.Tensor:
"""
:param x: input tensor of shape [B, T, F]
:param sequence_mask: mask tensor where 1 defines positions within the sequence and 0 outside, shape: [B, T]
:return: torch.Tensor of shape [B, T, F]
"""
for scale, module in zip(self.scales, self.module_list):
if isinstance(module, ConformerMHSARelPosV1):
x = scale * module(x, sequence_mask) + x
else:
x = scale * module(x) + x
x = self.final_layer_norm(x) # [B, T, F]
return x


@dataclass
class ConformerRelPosEncoderV1Config(ModelConfiguration):
"""
Attributes:
num_layers: Number of conformer layers in the conformer encoder
frontend: A pair of ConformerFrontend and corresponding config
block_cfg: Configuration for ConformerRelPosBlockV1
"""

num_layers: int

# nested configurations
frontend: ModuleFactoryV1
block_cfg: ConformerRelPosBlockV1Config


class ConformerRelPosEncoderV1(ConformerEncoderV2):
"""
Modifications compared to ConformerEncoderV2:
- supports Shaw's relative positional encoding using learnable position embeddings
and Transformer-XL style relative PE using fixed sinusoidal or learnable position embeddings
"""

def __init__(self, cfg: ConformerRelPosEncoderV1Config):
"""
:param cfg: conformer encoder configuration with subunits for frontend and conformer blocks
"""
super().__init__(cfg)

self.frontend = cfg.frontend()
self.module_list = torch.nn.ModuleList([ConformerRelPosBlockV1(cfg.block_cfg) for _ in range(cfg.num_layers)])
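
A rough end-to-end sketch (assumptions, not part of the commit): the frontend class and its config are hypothetical, the ModuleFactoryV1 call is assumed to take (module_class, cfg), and the forward signature is taken from ConformerEncoderV2, which is outside this diff.

from i6_models.config import ModuleFactoryV1
from i6_models.assemblies.conformer import ConformerRelPosEncoderV1, ConformerRelPosEncoderV1Config

encoder_cfg = ConformerRelPosEncoderV1Config(
    num_layers=12,
    frontend=ModuleFactoryV1(module_class=MyFrontendV1, cfg=my_frontend_cfg),  # hypothetical frontend and config
    block_cfg=block_cfg,  # e.g. the ConformerRelPosBlockV1Config sketched above
)
encoder = ConformerRelPosEncoderV1(encoder_cfg)
# assuming the inherited ConformerEncoderV2.forward takes (features [B, T, F'], sequence_mask [B, T])
out, out_mask = encoder(features, sequence_mask)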
1 change: 1 addition & 0 deletions i6_models/parts/conformer/__init__.py
@@ -1,4 +1,5 @@
from .convolution import *
from .feedforward import *
from .mhsa import *
from .mhsa_rel_pos import *
from .norm import *
68 changes: 66 additions & 2 deletions i6_models/parts/conformer/convolution.py
@@ -1,14 +1,20 @@
from __future__ import annotations

__all__ = ["ConformerConvolutionV1", "ConformerConvolutionV1Config"]
__all__ = [
"ConformerConvolutionV1",
"ConformerConvolutionV1Config",
"ConformerConvolutionV2",
"ConformerConvolutionV2Config",
]

from dataclasses import dataclass
from copy import deepcopy
from typing import Callable, Union, Optional, Literal

import torch
from torch import nn
from i6_models.config import ModelConfiguration
from typing import Callable, Union
from i6_models.parts.dropout import BroadcastDropout


@dataclass
@@ -85,3 +91,61 @@ def forward(self, tensor: torch.Tensor) -> torch.Tensor:
tensor = self.pointwise_conv2(tensor)

return self.dropout(tensor)


@dataclass
class ConformerConvolutionV2Config(ConformerConvolutionV1Config):
"""
New attribute:
dropout_broadcast_axes: string of axes to which dropout is broadcast, e.g. "T" for broadcasting to the time axis
set to None to disable broadcasting
"""

dropout_broadcast_axes: Optional[Literal["B", "T", "BT"]]

def check_valid(self):
assert self.kernel_size % 2 == 1, "ConformerConvolutionV2 only supports odd kernel sizes"

assert self.dropout_broadcast_axes in [
None,
"B",
"T",
"BT",
], "invalid value, supported are None, 'B', 'T' and 'BT'"


class ConformerConvolutionV2(ConformerConvolutionV1):
"""
Augments ConformerConvolutionV1 with dropout broadcasting
"""

def __init__(self, model_cfg: ConformerConvolutionV2Config):
"""
:param model_cfg: model configuration for this module
"""
super().__init__(model_cfg)

self.dropout = BroadcastDropout(model_cfg.dropout, dropout_broadcast_axes=model_cfg.dropout_broadcast_axes)

def forward(self, tensor: torch.Tensor) -> torch.Tensor:
"""
:param tensor: input tensor of shape [B,T,F]
:return: torch.Tensor of shape [B,T,F]
"""
tensor = self.layer_norm(tensor)
tensor = self.pointwise_conv1(tensor) # [B,T,2F]
tensor = nn.functional.glu(tensor, dim=-1) # [B,T,F]

# conv layers expect shape [B,F,T] so we have to transpose here
tensor = tensor.transpose(1, 2) # [B,F,T]
tensor = self.depthwise_conv(tensor)

tensor = self.norm(tensor)
tensor = tensor.transpose(1, 2) # transpose back to [B,T,F]

tensor = self.activation(tensor)
tensor = self.pointwise_conv2(tensor)

tensor = self.dropout(tensor)

return tensor
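
For intuition only: broadcasting dropout to the time axis ("T") means one dropout mask is sampled per (batch, feature) position and shared across all time steps, instead of sampling an independent mask per frame. The self-contained sketch below illustrates that idea in plain PyTorch; it is not the BroadcastDropout implementation, which lives in i6_models/parts/dropout.py and is not part of this diff.

import torch

def dropout_broadcast_over_time(x: torch.Tensor, p: float, training: bool = True) -> torch.Tensor:
    """Illustrative only: dropout on [B, T, F] with the mask broadcast over the time axis."""
    if not training or p == 0.0:
        return x
    # one keep/drop decision per (batch, feature) pair, shared across all T frames
    keep = (torch.rand(x.shape[0], 1, x.shape[2], device=x.device) >= p).to(x.dtype)
    return x * keep / (1.0 - p)

x = torch.randn(2, 50, 8)
y = dropout_broadcast_over_time(x, p=0.1)  # entire feature channels are zeroed for all frames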
64 changes: 62 additions & 2 deletions i6_models/parts/conformer/feedforward.py
@@ -1,14 +1,20 @@
from __future__ import annotations

__all__ = ["ConformerPositionwiseFeedForwardV1", "ConformerPositionwiseFeedForwardV1Config"]
__all__ = [
"ConformerPositionwiseFeedForwardV1",
"ConformerPositionwiseFeedForwardV1Config",
"ConformerPositionwiseFeedForwardV2",
"ConformerPositionwiseFeedForwardV2Config",
]

from dataclasses import dataclass
from typing import Callable
from typing import Callable, Optional, Literal

import torch
from torch import nn

from i6_models.config import ModelConfiguration
from i6_models.parts.dropout import BroadcastDropout


@dataclass
@@ -53,3 +59,57 @@ def forward(self, tensor: torch.Tensor) -> torch.Tensor:
tensor = self.linear_out(tensor) # [B,T,F]
tensor = nn.functional.dropout(tensor, p=self.dropout, training=self.training) # [B,T,F]
return tensor


@dataclass
class ConformerPositionwiseFeedForwardV2Config(ModelConfiguration):
"""
New attribute:
dropout_broadcast_axes: string of axes to which dropout is broadcast, e.g. "T" for broadcasting to the time axis
set to None to disable broadcasting
Default value for `activation` removed
"""

input_dim: int
hidden_dim: int
dropout: float
activation: Callable[[torch.Tensor], torch.Tensor]
dropout_broadcast_axes: Optional[Literal["B", "T", "BT"]]

def check_valid(self):
assert self.dropout_broadcast_axes in [
None,
"B",
"T",
"BT",
], "invalid value, supported are None, 'B', 'T' and 'BT'"

def __post_init__(self):
super().__post_init__()
self.check_valid()


class ConformerPositionwiseFeedForwardV2(ConformerPositionwiseFeedForwardV1):
"""
Augments ConformerPositionwiseFeedForwardV1 with dropout broadcasting
"""

def __init__(self, cfg: ConformerPositionwiseFeedForwardV2Config):
super().__init__(cfg)

self.dropout = BroadcastDropout(cfg.dropout, dropout_broadcast_axes=cfg.dropout_broadcast_axes)

def forward(self, tensor: torch.Tensor) -> torch.Tensor:
"""
:param tensor: shape [B,T,F], F=input_dim
:return: shape [B,T,F], F=input_dim
"""
tensor = self.layer_norm(tensor)
tensor = self.linear_ff(tensor) # [B,T,F]
tensor = self.activation(tensor) # [B,T,F]

tensor = self.dropout(tensor) # [B,T,F]
tensor = self.linear_out(tensor) # [B,T,F]
tensor = self.dropout(tensor) # [B,T,F]

return tensor
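
A small usage sketch for the new feed-forward module; the dimensions, dropout rate, and the choice of SiLU are illustrative only.

import torch
from torch import nn
from i6_models.parts.conformer import (
    ConformerPositionwiseFeedForwardV2,
    ConformerPositionwiseFeedForwardV2Config,
)

ff_cfg = ConformerPositionwiseFeedForwardV2Config(
    input_dim=512,
    hidden_dim=2048,
    dropout=0.1,
    activation=nn.functional.silu,
    dropout_broadcast_axes="T",  # broadcast the dropout mask over the time axis; None disables broadcasting
)
ff = ConformerPositionwiseFeedForwardV2(ff_cfg)
out = ff(torch.randn(4, 100, 512))  # [B, T, F] -> [B, T, F]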
59 changes: 58 additions & 1 deletion i6_models/parts/conformer/mhsa.py
@@ -1,11 +1,14 @@
from __future__ import annotations

__all__ = ["ConformerMHSAV1", "ConformerMHSAV1Config"]
__all__ = ["ConformerMHSAV1", "ConformerMHSAV1Config", "ConformerMHSAV2", "ConformerMHSAV2Config"]

from dataclasses import dataclass
from typing import Optional, Literal
import torch

from i6_models.config import ModelConfiguration
from i6_models.util import compat
from i6_models.parts.dropout import BroadcastDropout


@dataclass
@@ -60,3 +63,57 @@ def forward(self, input_tensor: torch.Tensor, sequence_mask: torch.Tensor) -> torch.Tensor:
output_tensor = torch.nn.functional.dropout(output_tensor, p=self.dropout, training=self.training) # [B,T,F]

return output_tensor


@dataclass
class ConformerMHSAV2Config(ConformerMHSAV1Config):
"""
New attribute:
dropout_broadcast_axes: string of axes to which dropout is broadcast, e.g. "T" for broadcasting to the time axis
set to None to disable broadcasting
"""

dropout_broadcast_axes: Optional[Literal["B", "T", "BT"]]

def check_valid(self):
assert self.dropout_broadcast_axes in [
None,
"B",
"T",
"BT",
], "invalid value, supported are None, 'B', 'T' and 'BT'"

def __post_init__(self):
super().__post_init__()
self.check_valid()


class ConformerMHSAV2(ConformerMHSAV1):
"""
Augments ConformerMHSAV1 with dropout broadcasting
"""

def __init__(self, cfg: ConformerMHSAV2Config):
super().__init__(cfg)

self.dropout = BroadcastDropout(cfg.dropout, dropout_broadcast_axes=cfg.dropout_broadcast_axes)

def forward(self, input_tensor: torch.Tensor, sequence_mask: torch.Tensor) -> torch.Tensor:
"""
Apply layer norm, multi-head self-attention and dropout
:param input_tensor: input to the self-attention, shape (B, T, F)
:param sequence_mask: bool mask of shape (B, T); True marks positions within the sequence, False marks padding.
The mask is inverted and passed as key_padding_mask to torch.nn.MultiheadAttention so that padded key positions are masked out of the attention.
"""
inv_sequence_mask = compat.logical_not(sequence_mask)
output_tensor = self.layernorm(input_tensor) # [B,T,F]

output_tensor, _ = self.mhsa(
output_tensor, output_tensor, output_tensor, key_padding_mask=inv_sequence_mask, need_weights=False
) # [B,T,F]

output_tensor = self.dropout(output_tensor)

return output_tensor # [B,T,F]
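
A usage sketch for the mask convention (True inside the sequence, False for padding); the config field names other than dropout_broadcast_axes come from ConformerMHSAV1Config, which sits outside this hunk, so treat them as assumptions.

import torch
from i6_models.parts.conformer import ConformerMHSAV2, ConformerMHSAV2Config

mhsa_cfg = ConformerMHSAV2Config(
    input_dim=512,
    num_att_heads=8,
    att_weights_dropout=0.1,
    dropout=0.1,
    dropout_broadcast_axes=None,  # plain, non-broadcast dropout
)
mhsa = ConformerMHSAV2(mhsa_cfg)

lengths = torch.tensor([100, 80, 60])
sequence_mask = torch.arange(100)[None, :] < lengths[:, None]  # [B, T] bool, True = real frame
out = mhsa(torch.randn(3, 100, 512), sequence_mask)  # [B, T, F]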
