Add diphone factored hybrid model backend
NeoLegends committed Jun 27, 2024
1 parent 56bf9fa commit 0496d33
Showing 4 changed files with 308 additions and 0 deletions.
10 changes: 10 additions & 0 deletions i6_models/parts/factored_hybrid/__init__.py
@@ -0,0 +1,10 @@
__all__ = [
    "DiphoneLogitsV1",
    "DiphoneProbsV1",
    "DiphoneBackendV1Config",
    "DiphoneBackendV1",
    "PhonemeStateClassV1",
]

from .diphone import *
from .util import PhonemeStateClassV1
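
The package root re-exports the public classes, so downstream code can import everything from one place. A minimal sketch of the intended import (names exactly as defined above):

from i6_models.parts.factored_hybrid import (
    DiphoneBackendV1,
    DiphoneBackendV1Config,
    PhonemeStateClassV1,
)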
167 changes: 167 additions & 0 deletions i6_models/parts/factored_hybrid/diphone.py
@@ -0,0 +1,167 @@
__all__ = [
    "DiphoneLogitsV1",
    "DiphoneProbsV1",
    "DiphoneBackendV1Config",
    "DiphoneBackendV1",
]

from dataclasses import dataclass
from typing import Callable, Optional, Union

import torch
from torch import nn, Tensor
import torch.nn.functional as F

from i6_models.config import ModelConfiguration

from .util import PhonemeStateClassV1, get_center_dim, get_mlp


@dataclass
class DiphoneLogitsV1:
    """outputs of a diphone factored hybrid model"""

    embeddings_left_context: Tensor
    """the embedded left context values"""

    output_center: Tensor
    """center output"""

    output_left: Tensor
    """left output"""


class DiphoneProbsV1(DiphoneLogitsV1):
    """marker class indicating that the output tensors store probabilities and not logits"""


@dataclass
class DiphoneBackendV1Config(ModelConfiguration):
    """
    Attributes:
        context_mix_mlp_dim: inner dimension of the context mixing MLP layers
        context_mix_mlp_num_layers: number of hidden layers in the context mixing MLPs
        dropout: dropout probability
        left_context_embedding_dim: embedding dimension of the left context
            values. A good choice is on the order of n_contexts.
        n_contexts: the number of raw phonemes/acoustic contexts
        num_hmm_states_per_phone: the number of HMM states per phoneme
        num_inputs: input dimension of the backend, must match the output dimension
            of the main encoder (e.g. Conformer)
        phoneme_state_class: the phoneme state augmentation to apply
        activation: activation function to use in the context mixing MLP
    """

    left_context_embedding_dim: int
    n_contexts: int
    num_hmm_states_per_phone: int
    phoneme_state_class: Union[int, PhonemeStateClassV1]

    activation: Callable[[], nn.Module]
    context_mix_mlp_dim: int
    context_mix_mlp_num_layers: int
    dropout: float
    num_inputs: int

    def __post_init__(self) -> None:
        super().__post_init__()

        assert self.context_mix_mlp_dim > 0
        assert self.context_mix_mlp_num_layers > 0
        assert self.n_contexts > 0
        assert self.num_hmm_states_per_phone > 0
        assert self.num_inputs > 0
        assert self.left_context_embedding_dim > 0
        assert 0.0 <= self.dropout <= 1.0, "dropout must be a probability"


class DiphoneBackendV1(nn.Module):
    """
    Diphone factored hybrid (FH) model backend.

    Consumes the output h(x) of a main encoder model and computes factored output
    logits/probabilities for p(c|l,x) and p(l|x).
    """

    def __init__(self, cfg: DiphoneBackendV1Config):
        super().__init__()

        self.n_contexts = cfg.n_contexts

        self.left_context_encoder = get_mlp(
            num_input=cfg.num_inputs,
            num_output=cfg.n_contexts,
            hidden_dim=cfg.context_mix_mlp_dim,
            num_layers=cfg.context_mix_mlp_num_layers,
            dropout=cfg.dropout,
            activation=cfg.activation,
        )
        self.left_context_embedding = nn.Embedding(cfg.n_contexts, cfg.left_context_embedding_dim)
        self.center_encoder = get_mlp(
            num_input=cfg.num_inputs + cfg.left_context_embedding_dim,
            num_output=get_center_dim(cfg.n_contexts, cfg.num_hmm_states_per_phone, cfg.phoneme_state_class),
            hidden_dim=cfg.context_mix_mlp_dim,
            num_layers=cfg.context_mix_mlp_num_layers,
            dropout=cfg.dropout,
            activation=cfg.activation,
        )

    def forward(
        self,
        features: Tensor,  # B, T, F
        contexts_left: Optional[Tensor] = None,  # B, T
    ) -> Union[DiphoneLogitsV1, DiphoneProbsV1]:
        """
        :param features: Main encoder output of shape B, T, F, where F = num_inputs.
        :param contexts_left: The left contexts used to compute p(c|l,x),
            shape B, T during training. Must be `None` in inference mode, where
            the model scores all possible left contexts and returns one flat,
            joint output.
        :return: During training: logits for p(c|l,x) and p(l|x).
            During inference: *log* probabilities for p(c,l|x) in one large layer.
        """

        left_logits = self.left_context_encoder(features)  # B, T, C

        if self.training:
            assert contexts_left is not None
            assert contexts_left.ndim >= 2
        else:
            assert contexts_left is None, "in eval mode, all left contexts are forwarded at the same time"
            # score every context at once; create the indices on the input's device
            contexts_left = torch.arange(self.n_contexts, device=features.device)

        # train: B, T, E
        # eval: C, E
        embedded_left_contexts = self.left_context_embedding(contexts_left)

        if self.training:
            # in training we forward exactly one context per T

            center_features = torch.cat((features, embedded_left_contexts), -1)  # B, T, F+E
            center_logits = self.center_encoder(center_features)  # B, T, F'

            return DiphoneLogitsV1(
                embeddings_left_context=embedded_left_contexts, output_left=left_logits, output_center=center_logits
            )
        else:
            # here we forward every context to compute p(c,l|x) = p(c|l,x) * p(l|x)

            features = features.expand((self.n_contexts, -1, -1, -1))  # C, B, T, F
            embedded_left_contexts_ = embedded_left_contexts.reshape((self.n_contexts, 1, 1, -1)).expand(
                (-1, *features.shape[1:3], -1)
            )  # C, B, T, E
            center_features = torch.cat((features, embedded_left_contexts_), -1)  # C, B, T, F+E
            center_logits = self.center_encoder(center_features)  # C, B, T, F'
            center_probs = F.log_softmax(center_logits, -1)
            center_probs = center_probs.permute((1, 2, 3, 0))  # B, T, F', C
            left_probs = F.log_softmax(left_logits, -1)
            left_probs = left_probs.unsqueeze(-2)  # B, T, 1, C

            # log p(c,l|x) = log p(c|l,x) + log p(l|x), broadcast over the context axis
            joint_probs = center_probs + left_probs  # B, T, F', C
            joint_probs = torch.flatten(joint_probs, start_dim=2)  # B, T, joint

            return DiphoneProbsV1(
                embeddings_left_context=embedded_left_contexts,
                output_left=left_probs.squeeze(-2),  # squeeze only the broadcast axis, keeping B and T
                output_center=joint_probs,
            )
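
For illustration, a minimal usage sketch of the backend in both modes, assuming the classes are imported as in the package __init__; the config values below are arbitrary placeholders, not tuned recommendations. In eval mode the joint output has center_dim * n_contexts entries per frame, since log p(c,l|x) = log p(c|l,x) + log p(l|x) is evaluated for every possible left context.

import torch
from torch import nn

# hypothetical example values, chosen only for illustration
cfg = DiphoneBackendV1Config(
    activation=lambda: nn.ReLU(),
    context_mix_mlp_dim=64,
    context_mix_mlp_num_layers=2,
    dropout=0.1,
    left_context_embedding_dim=32,
    n_contexts=42,
    num_hmm_states_per_phone=1,
    phoneme_state_class=PhonemeStateClassV1.word_end,
    num_inputs=256,
)
backend = DiphoneBackendV1(cfg)

features = torch.rand(8, 100, 256)  # B, T, F from the main encoder

backend.train()
contexts = torch.randint(0, 42, (8, 100))  # ground-truth left context labels
train_out = backend(features=features, contexts_left=contexts)  # DiphoneLogitsV1

backend.eval()
with torch.no_grad():
    eval_out = backend(features=features)  # DiphoneProbsV1
# joint output shape: B, T, center_dim * n_contexts = 8, 100, 84 * 42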
63 changes: 63 additions & 0 deletions i6_models/parts/factored_hybrid/util.py
@@ -0,0 +1,63 @@
from enum import Enum
from typing import Callable, Union

from torch import nn


class PhonemeStateClassV1(Enum):
    """Phoneme state class augmentation selector"""

    none = 1
    word_end = 2
    boundary = 4

    def factor(self) -> int:
        return self.value


def get_center_dim(
    n_contexts: int,
    num_hmm_states_per_phone: int,
    ph_class: Union[int, PhonemeStateClassV1],
) -> int:
    """
    :return: the dimension of the center output (number of center states) given the augmentation values
    """

    factor = ph_class.factor() if isinstance(ph_class, PhonemeStateClassV1) else ph_class
    return n_contexts * num_hmm_states_per_phone * factor


def get_mlp(
    num_input: int,
    num_output: int,
    hidden_dim: int,
    dropout: float,
    activation: Callable[[], nn.Module],
    num_layers: int,
) -> nn.Module:
    """
    :return: a context-mixing MLP according to the specifications
    """

    assert num_input > 0
    assert num_output > 0
    assert num_layers > 0
    assert hidden_dim > 0
    assert 0.0 <= dropout <= 1.0

    return nn.Sequential(
        *[
            layer
            for in_dim in [
                num_input,
                *[hidden_dim for _ in range(num_layers - 1)],
            ]
            for layer in [
                nn.Linear(in_dim, hidden_dim),
                nn.Dropout(dropout),
                activation(),
            ]
        ],
        nn.Linear(hidden_dim, num_output, bias=True),
    )
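
A quick worked example of the two helpers (values are arbitrary): with 42 contexts, 3 HMM states per phoneme, and word-end augmentation (factor 2), the center output has 42 * 3 * 2 = 252 classes, and get_mlp stacks num_layers blocks of Linear → Dropout → activation before the final output projection.

from torch import nn

cdim = get_center_dim(42, 3, PhonemeStateClassV1.word_end)  # 42 * 3 * 2 == 252

mlp = get_mlp(
    num_input=256,
    num_output=cdim,
    hidden_dim=64,
    dropout=0.1,
    activation=lambda: nn.ReLU(),
    num_layers=2,
)
# yields: Linear(256->64), Dropout, ReLU, Linear(64->64), Dropout, ReLU, Linear(64->252)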
68 changes: 68 additions & 0 deletions tests/test_fh.py
@@ -0,0 +1,68 @@
from itertools import product

import torch
import torch.nn as nn

from i6_models.parts.factored_hybrid import (
    DiphoneBackendV1,
    DiphoneBackendV1Config,
    PhonemeStateClassV1,
    DiphoneLogitsV1,
    DiphoneProbsV1,
)
from i6_models.parts.factored_hybrid.util import get_center_dim


def test_dim_calcs():
    n_ctx = 42

    assert get_center_dim(n_ctx, 1, PhonemeStateClassV1.none) == 42
    assert get_center_dim(n_ctx, 1, PhonemeStateClassV1.word_end) == 84
    assert get_center_dim(n_ctx, 3, PhonemeStateClassV1.word_end) == 252
    assert get_center_dim(n_ctx, 3, PhonemeStateClassV1.boundary) == 504


def test_output_shape():
    n_ctx = 42
    n_in = 32

    for we_class, states_per_ph in product(
        [PhonemeStateClassV1.none, PhonemeStateClassV1.word_end, PhonemeStateClassV1.boundary],
        [1, 3],
    ):
        backend = DiphoneBackendV1(
            DiphoneBackendV1Config(
                activation=lambda: nn.ReLU(),
                context_mix_mlp_dim=64,
                context_mix_mlp_num_layers=2,
                dropout=0.1,
                left_context_embedding_dim=32,
                n_contexts=n_ctx,
                num_hmm_states_per_phone=states_per_ph,
                num_inputs=n_in,
                phoneme_state_class=we_class,
            )
        )

        backend.train(True)
        for b, t in product([10, 50, 100], [10, 50, 100]):
            contexts_forward = torch.randint(0, n_ctx, (b, t))
            encoder_output = torch.rand((b, t, n_in))
            output = backend(features=encoder_output, contexts_left=contexts_forward)
            assert isinstance(output, DiphoneLogitsV1) and not isinstance(output, DiphoneProbsV1)
            assert output.output_left.shape == (b, t, n_ctx)
            cdim = get_center_dim(n_ctx, states_per_ph, we_class)
            assert output.output_center.shape == (b, t, cdim)

        backend.train(False)
        for b, t in product([10, 50, 100], [10, 50, 100]):
            encoder_output = torch.rand((b, t, n_in))
            output = backend(features=encoder_output)
            assert isinstance(output, DiphoneProbsV1)
            assert output.output_left.shape == (b, t, n_ctx)
            cdim = get_center_dim(n_ctx, states_per_ph, we_class)
            assert output.output_center.shape == (b, t, cdim * n_ctx)
            output_p = torch.exp(output.output_center)
            ones_hopefully = torch.sum(output_p, dim=-1)
            close_to_one = torch.abs(1 - ones_hopefully).flatten() < 1e-3
            assert all(close_to_one)
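
The final sum-to-one assertion holds because the eval output is log p(c|l,x) + log p(l|x) flattened over all (center, left) pairs: exponentiating and summing over the last axis gives the sum over l of p(l|x) times the sum over c of p(c|l,x), which equals 1. A standalone toy sketch of the same identity, independent of the model code:

import torch
import torch.nn.functional as F

center = F.log_softmax(torch.rand(5, 4), -1)  # log p(c|l) for 5 left contexts, 4 center classes
left = F.log_softmax(torch.rand(5), -1)  # log p(l)
joint = (center + left.unsqueeze(-1)).flatten()  # log p(c,l) for all 20 pairs
assert torch.allclose(torch.exp(joint).sum(), torch.tensor(1.0))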
