-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add MHSA module, Conformer block and encoder with relative PE (#55)
Co-authored-by: Albert Zeyer <[email protected]>
- Loading branch information
1 parent
1972683
commit 9c0fe3f
Showing
11 changed files
with
842 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,3 @@ | ||
from .conformer_v1 import * | ||
from .conformer_v2 import * | ||
from .conformer_rel_pos_v1 import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,126 @@ | ||
from __future__ import annotations | ||
|
||
__all__ = [ | ||
"ConformerRelPosBlockV1Config", | ||
"ConformerRelPosEncoderV1Config", | ||
"ConformerRelPosBlockV1", | ||
"ConformerRelPosEncoderV1", | ||
] | ||
|
||
import torch | ||
from torch import nn | ||
from dataclasses import dataclass, field | ||
from typing import List | ||
|
||
from i6_models.config import ModelConfiguration, ModuleFactoryV1 | ||
from i6_models.parts.conformer import ( | ||
ConformerConvolutionV2, | ||
ConformerConvolutionV2Config, | ||
ConformerMHSARelPosV1, | ||
ConformerMHSARelPosV1Config, | ||
ConformerPositionwiseFeedForwardV2, | ||
ConformerPositionwiseFeedForwardV2Config, | ||
) | ||
from i6_models.assemblies.conformer import ConformerEncoderV2 | ||
|
||
|
||
@dataclass | ||
class ConformerRelPosBlockV1Config(ModelConfiguration): | ||
""" | ||
Attributes: | ||
ff_cfg: Configuration for ConformerPositionwiseFeedForwardV2 | ||
mhsa_cfg: Configuration for ConformerMHSARelPosV1 | ||
conv_cfg: Configuration for ConformerConvolutionV2 | ||
modules: List of modules to use for ConformerRelPosBlockV1, | ||
"ff" for feed forward module, "mhsa" for multi-head self attention module, "conv" for conv module | ||
scales: List of scales to apply to the module outputs before the residual connection | ||
""" | ||
|
||
# nested configurations | ||
ff_cfg: ConformerPositionwiseFeedForwardV2Config | ||
mhsa_cfg: ConformerMHSARelPosV1Config | ||
conv_cfg: ConformerConvolutionV2Config | ||
modules: List[str] = field(default_factory=lambda: ["ff", "mhsa", "conv", "ff"]) | ||
scales: List[float] = field(default_factory=lambda: [0.5, 1.0, 1.0, 0.5]) | ||
|
||
def __post__init__(self): | ||
super().__post_init__() | ||
assert len(self.modules) == len(self.scales), "modules and scales must have same length" | ||
for module_name in self.modules: | ||
assert module_name in ["ff", "mhsa", "conv"], "module not supported" | ||
|
||
|
||
class ConformerRelPosBlockV1(nn.Module): | ||
""" | ||
Conformer block module, modifications compared to ConformerBlockV1: | ||
- uses ConfomerMHSARelPosV1 as MHSA module | ||
- enable constructing the block with self-defined module_list as ConformerBlockV2 | ||
""" | ||
|
||
def __init__(self, cfg: ConformerRelPosBlockV1Config): | ||
""" | ||
:param cfg: conformer block configuration with subunits for the different conformer parts | ||
""" | ||
super().__init__() | ||
|
||
modules = [] | ||
for module_name in cfg.modules: | ||
if module_name == "ff": | ||
modules.append(ConformerPositionwiseFeedForwardV2(cfg=cfg.ff_cfg)) | ||
elif module_name == "mhsa": | ||
modules.append(ConformerMHSARelPosV1(cfg=cfg.mhsa_cfg)) | ||
elif module_name == "conv": | ||
modules.append(ConformerConvolutionV2(model_cfg=cfg.conv_cfg)) | ||
else: | ||
raise NotImplementedError | ||
|
||
self.module_list = nn.ModuleList(modules) | ||
self.scales = cfg.scales | ||
self.final_layer_norm = torch.nn.LayerNorm(cfg.ff_cfg.input_dim) | ||
|
||
def forward(self, x: torch.Tensor, /, sequence_mask: torch.Tensor) -> torch.Tensor: | ||
""" | ||
:param x: input tensor of shape [B, T, F] | ||
:param sequence_mask: mask tensor where 1 defines positions within the sequence and 0 outside, shape: [B, T] | ||
:return: torch.Tensor of shape [B, T, F] | ||
""" | ||
for scale, module in zip(self.scales, self.module_list): | ||
if isinstance(module, ConformerMHSARelPosV1): | ||
x = scale * module(x, sequence_mask) + x | ||
else: | ||
x = scale * module(x) + x | ||
x = self.final_layer_norm(x) # [B, T, F] | ||
return x | ||
|
||
|
||
@dataclass | ||
class ConformerRelPosEncoderV1Config(ModelConfiguration): | ||
""" | ||
Attributes: | ||
num_layers: Number of conformer layers in the conformer encoder | ||
frontend: A pair of ConformerFrontend and corresponding config | ||
block_cfg: Configuration for ConformerRelPosBlockV1 | ||
""" | ||
|
||
num_layers: int | ||
|
||
# nested configurations | ||
frontend: ModuleFactoryV1 | ||
block_cfg: ConformerRelPosBlockV1Config | ||
|
||
|
||
class ConformerRelPosEncoderV1(ConformerEncoderV2): | ||
""" | ||
Modifications compared to ConformerEncoderV2: | ||
- supports Shaw's relative positional encoding using learnable position embeddings | ||
and Transformer-XL style relative PE using fixed sinusoidal or learnable position embeddings | ||
""" | ||
|
||
def __init__(self, cfg: ConformerRelPosEncoderV1Config): | ||
""" | ||
:param cfg: conformer encoder configuration with subunits for frontend and conformer blocks | ||
""" | ||
super().__init__(cfg) | ||
|
||
self.frontend = cfg.frontend() | ||
self.module_list = torch.nn.ModuleList([ConformerRelPosBlockV1(cfg.block_cfg) for _ in range(cfg.num_layers)]) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
from .convolution import * | ||
from .feedforward import * | ||
from .mhsa import * | ||
from .mhsa_rel_pos import * | ||
from .norm import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.