Add ConformerEncoderV2 (#49)
* add Conformer V2

* update init

* update

* black

* Update i6_models/assemblies/conformer/conformer_v2.py

Co-authored-by: Albert Zeyer <[email protected]>

* Update i6_models/assemblies/conformer/conformer_v2.py

Co-authored-by: Albert Zeyer <[email protected]>

* update

* make conformer block customizable

* unify names

* fix default member values

* fixes

* fix cfg name

* fix name collision

* Update black.yml

* Update model_tests.yml

* remove unused variable

* Update i6_models/assemblies/conformer/conformer_v2.py

Co-authored-by: Albert Zeyer <[email protected]>

* update

* update

* fix import

* update

* update

* update

---------

Co-authored-by: Jingjing Xu <[email protected]>
Co-authored-by: Albert Zeyer <[email protected]>
Co-authored-by: pzheng <[email protected]>
Co-authored-by: Ping Zheng <[email protected]>
Co-authored-by: Nick Rossenbach <[email protected]>
6 people authored May 29, 2024
1 parent 83ff39e commit 645d4e1
Showing 2 changed files with 155 additions and 0 deletions.
1 change: 1 addition & 0 deletions i6_models/assemblies/conformer/__init__.py
@@ -1 +1,2 @@
from .conformer_v1 import *
from .conformer_v2 import *
154 changes: 154 additions & 0 deletions i6_models/assemblies/conformer/conformer_v2.py
@@ -0,0 +1,154 @@
from __future__ import annotations

__all__ = ["ConformerBlockV2Config", "ConformerEncoderV2Config", "ConformerBlockV2", "ConformerEncoderV2"]

import torch
from torch import nn

from dataclasses import dataclass, field
from typing import Tuple, List, Optional

from i6_models.config import ModelConfiguration, ModuleFactoryV1
from .conformer_v1 import (
    ConformerConvolutionV1,
    ConformerConvolutionV1Config,
    ConformerMHSAV1,
    ConformerMHSAV1Config,
    ConformerPositionwiseFeedForwardV1,
    ConformerPositionwiseFeedForwardV1Config,
)


@dataclass
class ConformerBlockV2Config(ModelConfiguration):
"""
Attributes:
ff_cfg: Configuration for ConformerPositionwiseFeedForwardV1
mhsa_cfg: Configuration for ConformerMHSAV1
conv_cfg: Configuration for ConformerConvolutionV1
modules: List of modules to use for ConformerBlockV2,
"ff" for feed forward module, "mhsa" for multi-head self attention module, "conv" for conv module
scales: List of scales to apply to the module outputs before the residual connection
"""

# nested configurations
ff_cfg: ConformerPositionwiseFeedForwardV1Config
mhsa_cfg: ConformerMHSAV1Config
conv_cfg: ConformerConvolutionV1Config
modules: List[str] = field(default_factory=lambda: ["ff", "mhsa", "conv", "ff"])
scales: List[float] = field(default_factory=lambda: [0.5, 1.0, 1.0, 0.5])

    def __post_init__(self):
        super().__post_init__()
        assert len(self.modules) == len(self.scales), "modules and scales must have the same length"
        for module_name in self.modules:
            assert module_name in ["ff", "mhsa", "conv"], "module not supported"


class ConformerBlockV2(nn.Module):
"""
Modifications compared to ConformerBlockV1:
- more generic, enable constructing the block with self-defined module_list.
"""

    def __init__(self, cfg: ConformerBlockV2Config):
        """
        :param cfg: conformer block configuration with subunits for the different conformer parts
        """
        super().__init__()

        modules = []
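        # instantiate the configured submodules in the order given by cfg.modules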
        for module_name in cfg.modules:
            if module_name == "ff":
                modules.append(ConformerPositionwiseFeedForwardV1(cfg=cfg.ff_cfg))
            elif module_name == "mhsa":
                modules.append(ConformerMHSAV1(cfg=cfg.mhsa_cfg))
            elif module_name == "conv":
                modules.append(ConformerConvolutionV1(model_cfg=cfg.conv_cfg))
            else:
                raise NotImplementedError

        self.module_list = nn.ModuleList(modules)
        self.scales = cfg.scales
        self.final_layer_norm = torch.nn.LayerNorm(cfg.ff_cfg.input_dim)

    def forward(self, x: torch.Tensor, /, sequence_mask: torch.Tensor) -> torch.Tensor:
        """
        :param x: input tensor of shape [B, T, F]
        :param sequence_mask: mask tensor where 1 defines positions within the sequence and 0 outside, shape: [B, T]
        :return: torch.Tensor of shape [B, T, F]
        """
        for scale, module in zip(self.scales, self.module_list):
            if isinstance(module, ConformerMHSAV1):
                x = scale * module(x, sequence_mask) + x
            else:
                x = scale * module(x) + x

        x = self.final_layer_norm(x)  # [B, T, F]
        return x


@dataclass
class ConformerEncoderV2Config(ModelConfiguration):
"""
Attributes:
num_layers: Number of conformer layers in the conformer encoder
frontend: A pair of ConformerFrontend and corresponding config
block_cfg: Configuration for ConformerBlockV2
"""

num_layers: int

# nested configurations
frontend: ModuleFactoryV1
block_cfg: ConformerBlockV2Config


class ConformerEncoderV2(nn.Module):
"""
Modifications compared to ConformerEncoderV1:
- modules inside each conformer block could be customized.
- return a list of layer outputs with desired layer indices
"""

    def __init__(self, cfg: ConformerEncoderV2Config):
        """
        :param cfg: conformer encoder configuration with subunits for frontend and conformer blocks
        """
        super().__init__()

        self.frontend = cfg.frontend()
        self.module_list = torch.nn.ModuleList([ConformerBlockV2(cfg.block_cfg) for _ in range(cfg.num_layers)])

    def forward(
        self, data_tensor: torch.Tensor, /, sequence_mask: torch.Tensor, return_layers: Optional[List[int]] = None
    ) -> Tuple[List[torch.Tensor], torch.Tensor]:
        """
        :param data_tensor: input tensor of shape [B, T', F]
        :param sequence_mask: mask tensor where 1 defines positions within the sequence and 0 outside, shape: [B, T']
        :param return_layers: list of layer indices specifying which layers to return, starting from 0
        :return: (outputs, out_seq_mask)
            where outputs is a list of torch.Tensor of shape [B, T, F'],
            one for each of the layers in return_layers,
            and out_seq_mask is a torch.Tensor of shape [B, T]
            F: input feature dim, F': internal and output feature dim
            T': data time dim, T: down-sampled time dim (internal time dim)
        """

        if return_layers is None:
            return_layers = [len(self.module_list) - 1]

        x, sequence_mask = self.frontend(data_tensor, sequence_mask)  # [B, T, F']

        outputs = []
        assert (
            max(return_layers) < len(self.module_list) and min(return_layers) >= 0
        ), f"invalid layer index, should be between 0 and {len(self.module_list)-1}"

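        # run the blocks only up to the highest requested layer and collect the requested outputs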
        for i in range(max(return_layers) + 1):
            x = self.module_list[i](x, sequence_mask)  # [B, T, F']
            if i in return_layers:
                outputs.append(x)

        return outputs, sequence_mask
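
For reference, a minimal usage sketch of the new encoder (not part of the commit). It assumes that ff_cfg, mhsa_cfg, conv_cfg and a frontend ModuleFactoryV1 have already been built exactly as one would for ConformerEncoderV1, and it only shows the V2-specific options (modules, scales, return_layers):

import torch

from i6_models.assemblies.conformer import (
    ConformerBlockV2Config,
    ConformerEncoderV2,
    ConformerEncoderV2Config,
)

block_cfg = ConformerBlockV2Config(
    ff_cfg=ff_cfg,  # assumed to exist, built as for ConformerEncoderV1
    mhsa_cfg=mhsa_cfg,  # assumed to exist
    conv_cfg=conv_cfg,  # assumed to exist
    modules=["ff", "mhsa", "conv", "ff"],  # the default: Macaron-style block with two half-step feed-forward modules
    scales=[0.5, 1.0, 1.0, 0.5],  # output scale per module, applied before the residual connection
)
encoder = ConformerEncoderV2(
    ConformerEncoderV2Config(num_layers=12, frontend=frontend, block_cfg=block_cfg)  # frontend assumed to exist
)

audio_features = torch.randn(4, 100, 80)  # [B, T', F]
sequence_mask = torch.ones(4, 100, dtype=torch.bool)  # True inside the sequence, False outside

# request intermediate outputs, e.g. for auxiliary losses; return_layers=None returns only the last layer
outputs, out_mask = encoder(audio_features, sequence_mask=sequence_mask, return_layers=[5, 11])
assert len(outputs) == 2  # one tensor of shape [B, T, F'] per requested layer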
