-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add diphone factored hybrid model backend
- Loading branch information
1 parent
56bf9fa
commit 0496d33
Showing
4 changed files
with
308 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
__all__ = [ | ||
"DiphoneLogitsV1", | ||
"DiphoneProbsV1", | ||
"DiphoneBackendV1Config", | ||
"DiphoneBackendV1", | ||
"PhonemeStateClassV1", | ||
] | ||
|
||
from .diphone import * | ||
from .util import PhonemeStateClassV1 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,167 @@ | ||
__all__ = [ | ||
"DiphoneLogitsV1", | ||
"DiphoneProbsV1", | ||
"DiphoneBackendV1Config", | ||
"DiphoneBackendV1", | ||
] | ||
|
||
from dataclasses import dataclass | ||
from typing import Callable, Optional, Union | ||
|
||
import torch | ||
from torch import nn, Tensor | ||
import torch.nn.functional as F | ||
|
||
from i6_models.config import ModelConfiguration | ||
|
||
from .util import PhonemeStateClassV1, get_center_dim, get_mlp | ||
|
||
|
||
@dataclass | ||
class DiphoneLogitsV1: | ||
"""outputs of a diphone factored hybrid model""" | ||
|
||
embeddings_left_context: Tensor | ||
"""the embedded left context values""" | ||
|
||
output_center: Tensor | ||
"""center output""" | ||
|
||
output_left: Tensor | ||
"""left output""" | ||
|
||
|
||
class DiphoneProbsV1(DiphoneLogitsV1): | ||
"""marker class indicating that the output tensors store probabilities and not logits""" | ||
|
||
|
||
@dataclass | ||
class DiphoneBackendV1Config(ModelConfiguration): | ||
""" | ||
Attributes: | ||
context_mix_mlp_dim: inner dimension of the context mixing MLP layers | ||
context_mix_mlp_num_layers: how many hidden layers on the MLPs there should be | ||
dropout: dropout probabilty | ||
left_context_embedding_dim: embedding dimension of the left context | ||
values. Good choice is in the order of n_contexts. | ||
n_contexts: the number of raw phonemes/acoustic contexts | ||
num_hmm_states_per_phone: the number of HMM states per phoneme | ||
num_inputs: input dimension of the backend, must match w/ output dimension | ||
of main encoder (e.g. Conformer) | ||
phoneme_state_class: the phoneme state augmentation to apply | ||
activation: activation function to use in the context mixing MLP. | ||
""" | ||
|
||
left_context_embedding_dim: int | ||
n_contexts: int | ||
num_hmm_states_per_phone: int | ||
phoneme_state_class: Union[int, PhonemeStateClassV1] | ||
|
||
activation: Callable[[], nn.Module] | ||
context_mix_mlp_dim: int | ||
context_mix_mlp_num_layers: int | ||
dropout: float | ||
num_inputs: int | ||
|
||
def __post_init__(self) -> None: | ||
super().__post_init__() | ||
|
||
assert self.context_mix_mlp_dim > 0 | ||
assert self.context_mix_mlp_num_layers > 0 | ||
assert self.n_contexts > 0 | ||
assert self.num_hmm_states_per_phone > 0 | ||
assert self.num_inputs > 0 | ||
assert self.left_context_embedding_dim > 0 | ||
assert 0.0 <= self.dropout <= 1.0, "dropout must be a probability" | ||
|
||
|
||
class DiphoneBackendV1(nn.Module): | ||
""" | ||
Diphone FH model backend. | ||
Consumes the output h(x) of a main encoder model and computes factored output | ||
logits/probabilities for p(c|l,x) and p(l|x). | ||
""" | ||
|
||
def __init__(self, cfg: DiphoneBackendV1Config): | ||
super().__init__() | ||
|
||
self.n_contexts = cfg.n_contexts | ||
|
||
self.left_context_encoder = get_mlp( | ||
num_input=cfg.num_inputs, | ||
num_output=cfg.n_contexts, | ||
hidden_dim=cfg.context_mix_mlp_dim, | ||
num_layers=cfg.context_mix_mlp_num_layers, | ||
dropout=cfg.dropout, | ||
activation=cfg.activation, | ||
) | ||
self.left_context_embedding = nn.Embedding(cfg.n_contexts, cfg.left_context_embedding_dim) | ||
self.center_encoder = get_mlp( | ||
num_input=cfg.num_inputs + cfg.left_context_embedding_dim, | ||
num_output=get_center_dim(cfg.n_contexts, cfg.num_hmm_states_per_phone, cfg.phoneme_state_class), | ||
hidden_dim=cfg.context_mix_mlp_dim, | ||
num_layers=cfg.context_mix_mlp_num_layers, | ||
dropout=cfg.dropout, | ||
activation=cfg.activation, | ||
) | ||
|
||
def forward( | ||
self, | ||
features: Tensor, # B, T, F | ||
contexts_left: Optional[Tensor] = None, # B, T | ||
) -> Union[DiphoneLogitsV1, DiphoneProbsV1]: | ||
""" | ||
:param features: Main encoder output. shape B, T, F. F=num_inputs | ||
:param contexts_left: The left contexts used to compute p(c|l,x). | ||
Shape during training: B, T. Must be `None` when forwarding as the model | ||
will compute a flat, joint output scoring all possible contexts in forwarding | ||
mode. | ||
:return: During training: logits for p(c|l,x) and p(l|x). | ||
During inference: *log* probs for p(c,l|x) in one large layer. | ||
""" | ||
|
||
left_logits = self.left_context_encoder(features) # B, T, C | ||
|
||
if self.training: | ||
assert contexts_left is not None | ||
assert contexts_left.ndim >= 2 | ||
else: | ||
assert contexts_left is None, "in eval mode, all left contexts are forwarded at the same time" | ||
contexts_left = torch.arange(self.n_contexts) | ||
|
||
# train: B, T, E | ||
# eval: C, E | ||
embedded_left_contexts = self.left_context_embedding(contexts_left) | ||
|
||
if self.training: | ||
# in training we forward exactly one context per T | ||
|
||
center_features = torch.cat((features, embedded_left_contexts), -1) # B, T, F+E | ||
center_logits = self.center_encoder(center_features) # B, T, C | ||
|
||
return DiphoneLogitsV1( | ||
embeddings_left_context=embedded_left_contexts, output_left=left_logits, output_center=center_logits | ||
) | ||
else: | ||
# here we forward every context to compute p(c, l|x) = p(c|l, x) * p(l|x) | ||
|
||
features = features.expand((self.n_contexts, -1, -1, -1)) # C, B, T, F | ||
embedded_left_contexts_ = embedded_left_contexts.reshape((self.n_contexts, 1, 1, -1)).expand( | ||
(-1, *features.shape[1:3], -1) | ||
) # C, B, T, E | ||
center_features = torch.cat((features, embedded_left_contexts_), -1) # C, B, T, F+E | ||
center_logits = self.center_encoder(center_features) # C, B, T, F' | ||
center_probs = F.log_softmax(center_logits, -1) | ||
center_probs = center_probs.permute((1, 2, 3, 0)) # B, T, F', C | ||
left_probs = F.log_softmax(left_logits, -1) | ||
left_probs = left_probs.unsqueeze(-2) # B, T, 1, C | ||
|
||
joint_probs = center_probs + left_probs # B, T, F', C | ||
joint_probs = torch.flatten(joint_probs, start_dim=2) # B, T, joint | ||
|
||
return DiphoneProbsV1( | ||
embeddings_left_context=embedded_left_contexts, | ||
output_left=left_probs.squeeze(), | ||
output_center=joint_probs, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
from enum import Enum | ||
from typing import Callable, Union | ||
|
||
from torch import nn | ||
|
||
|
||
class PhonemeStateClassV1(Enum): | ||
"""Phoneme state class augmentation selector""" | ||
|
||
none = 1 | ||
word_end = 2 | ||
boundary = 4 | ||
|
||
def factor(self): | ||
return self.value | ||
|
||
|
||
def get_center_dim( | ||
n_contexts: int, | ||
num_hmm_states_per_phone: int, | ||
ph_class: Union[int, PhonemeStateClassV1], | ||
) -> int: | ||
""" | ||
:return: number of center phonemes given the augmentation values | ||
""" | ||
|
||
factor = ph_class.factor() if isinstance(ph_class, PhonemeStateClassV1) else ph_class | ||
return n_contexts * num_hmm_states_per_phone * factor | ||
|
||
|
||
def get_mlp( | ||
num_input: int, | ||
num_output: int, | ||
hidden_dim: int, | ||
dropout: float, | ||
activation: Callable[[], nn.Module], | ||
num_layers, | ||
) -> nn.Module: | ||
""" | ||
:return: a context-mixing MLP according to the specifications | ||
""" | ||
|
||
assert num_input > 0 | ||
assert num_output > 0 | ||
assert num_layers > 0 | ||
assert hidden_dim > 0 | ||
assert 0.0 <= dropout <= 1.0 | ||
|
||
return nn.Sequential( | ||
*[ | ||
layer | ||
for in_dim in [ | ||
num_input, | ||
*[hidden_dim for _ in range(num_layers - 1)], | ||
] | ||
for layer in [ | ||
nn.Linear(in_dim, hidden_dim), | ||
nn.Dropout(dropout), | ||
activation(), | ||
] | ||
], | ||
nn.Linear(hidden_dim, num_output, bias=True), | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
from itertools import product | ||
|
||
import torch | ||
import torch.nn as nn | ||
|
||
from i6_models.parts.factored_hybrid import ( | ||
DiphoneBackendV1, | ||
DiphoneBackendV1Config, | ||
PhonemeStateClassV1, | ||
DiphoneLogitsV1, | ||
DiphoneProbsV1, | ||
) | ||
from i6_models.parts.factored_hybrid.util import get_center_dim | ||
|
||
|
||
def test_dim_calcs(): | ||
n_ctx = 42 | ||
|
||
assert get_center_dim(n_ctx, 1, PhonemeStateClassV1.none) == 42 | ||
assert get_center_dim(n_ctx, 1, PhonemeStateClassV1.word_end) == 84 | ||
assert get_center_dim(n_ctx, 3, PhonemeStateClassV1.word_end) == 252 | ||
assert get_center_dim(n_ctx, 3, PhonemeStateClassV1.boundary) == 504 | ||
|
||
|
||
def test_output_shape(): | ||
n_ctx = 42 | ||
n_in = 32 | ||
|
||
for we_class, states_per_ph in product( | ||
[PhonemeStateClassV1.none, PhonemeStateClassV1.word_end, PhonemeStateClassV1.boundary], | ||
[1, 3], | ||
): | ||
backend = DiphoneBackendV1( | ||
DiphoneBackendV1Config( | ||
activation=lambda: nn.ReLU(), | ||
context_mix_mlp_dim=64, | ||
context_mix_mlp_num_layers=2, | ||
dropout=0.1, | ||
left_context_embedding_dim=32, | ||
n_contexts=n_ctx, | ||
num_hmm_states_per_phone=states_per_ph, | ||
num_inputs=n_in, | ||
phoneme_state_class=we_class, | ||
) | ||
) | ||
|
||
backend.train(True) | ||
for b, t in product([10, 50, 100], [10, 50, 100]): | ||
contexts_forward = torch.randint(0, n_ctx, (b, t)) | ||
encoder_output = torch.rand((b, t, n_in)) | ||
output = backend(features=encoder_output, contexts_left=contexts_forward) | ||
assert isinstance(output, DiphoneLogitsV1) and not isinstance(output, DiphoneProbsV1) | ||
assert output.output_left.shape == (b, t, n_ctx) | ||
cdim = get_center_dim(n_ctx, states_per_ph, we_class) | ||
assert output.output_center.shape == (b, t, cdim) | ||
|
||
backend.train(False) | ||
for b, t in product([10, 50, 100], [10, 50, 100]): | ||
encoder_output = torch.rand((b, t, n_in)) | ||
output = backend(features=encoder_output) | ||
assert isinstance(output, DiphoneProbsV1) | ||
assert output.output_left.shape == (b, t, n_ctx) | ||
cdim = get_center_dim(n_ctx, states_per_ph, we_class) | ||
assert output.output_center.shape == (b, t, cdim * n_ctx) | ||
output_p = torch.exp(output.output_center) | ||
ones_hopefully = torch.sum(output_p, dim=-1) | ||
close_to_one = torch.abs(1 - ones_hopefully).flatten() < 1e-3 | ||
assert all(close_to_one) |