diff --git a/i6_models/parts/factored_hybrid/__init__.py b/i6_models/parts/factored_hybrid/__init__.py
new file mode 100644
index 00000000..3f19d8aa
--- /dev/null
+++ b/i6_models/parts/factored_hybrid/__init__.py
@@ -0,0 +1,10 @@
+__all__ = [
+    "DiphoneLogitsV1",
+    "DiphoneProbsV1",
+    "DiphoneBackendV1Config",
+    "DiphoneBackendV1",
+    "PhonemeStateClassV1",
+]
+
+from .diphone import *
+from .util import PhonemeStateClassV1
diff --git a/i6_models/parts/factored_hybrid/diphone.py b/i6_models/parts/factored_hybrid/diphone.py
new file mode 100644
index 00000000..8837226f
--- /dev/null
+++ b/i6_models/parts/factored_hybrid/diphone.py
@@ -0,0 +1,167 @@
+__all__ = [
+    "DiphoneLogitsV1",
+    "DiphoneProbsV1",
+    "DiphoneBackendV1Config",
+    "DiphoneBackendV1",
+]
+
+from dataclasses import dataclass
+from typing import Callable, Optional, Union
+
+import torch
+from torch import nn, Tensor
+import torch.nn.functional as F
+
+from i6_models.config import ModelConfiguration
+
+from .util import PhonemeStateClassV1, get_center_dim, get_mlp
+
+
+@dataclass
+class DiphoneLogitsV1:
+    """outputs of a diphone factored hybrid model"""
+
+    embeddings_left_context: Tensor
+    """the embedded left context values"""
+
+    output_center: Tensor
+    """center output"""
+
+    output_left: Tensor
+    """left output"""
+
+
+class DiphoneProbsV1(DiphoneLogitsV1):
+    """marker class indicating that the output tensors store probabilities and not logits"""
+
+
+@dataclass
+class DiphoneBackendV1Config(ModelConfiguration):
+    """
+    Attributes:
+        context_mix_mlp_dim: inner dimension of the context mixing MLP layers
+        context_mix_mlp_num_layers: number of hidden layers in the context mixing MLPs
+        dropout: dropout probability
+        left_context_embedding_dim: embedding dimension of the left context
+            values. A good choice is on the order of n_contexts.
+        n_contexts: the number of raw phonemes/acoustic contexts
+        num_hmm_states_per_phone: the number of HMM states per phoneme
+        num_inputs: input dimension of the backend, must match the output dimension
+            of the main encoder (e.g. Conformer)
+        phoneme_state_class: the phoneme state augmentation to apply
+        activation: activation function to use in the context mixing MLPs
+    """
+
+    left_context_embedding_dim: int
+    n_contexts: int
+    num_hmm_states_per_phone: int
+    phoneme_state_class: Union[int, PhonemeStateClassV1]
+
+    activation: Callable[[], nn.Module]
+    context_mix_mlp_dim: int
+    context_mix_mlp_num_layers: int
+    dropout: float
+    num_inputs: int
+
+    def __post_init__(self) -> None:
+        super().__post_init__()
+
+        assert self.context_mix_mlp_dim > 0
+        assert self.context_mix_mlp_num_layers > 0
+        assert self.n_contexts > 0
+        assert self.num_hmm_states_per_phone > 0
+        assert self.num_inputs > 0
+        assert self.left_context_embedding_dim > 0
+        assert 0.0 <= self.dropout <= 1.0, "dropout must be a probability"
+
+
+class DiphoneBackendV1(nn.Module):
+    """
+    Diphone FH model backend.
+
+    Consumes the output h(x) of a main encoder model and computes factored output
+    logits/probabilities for p(c|l,x) and p(l|x).
+ """ + + def __init__(self, cfg: DiphoneBackendV1Config): + super().__init__() + + self.n_contexts = cfg.n_contexts + + self.left_context_encoder = get_mlp( + num_input=cfg.num_inputs, + num_output=cfg.n_contexts, + hidden_dim=cfg.context_mix_mlp_dim, + num_layers=cfg.context_mix_mlp_num_layers, + dropout=cfg.dropout, + activation=cfg.activation, + ) + self.left_context_embedding = nn.Embedding(cfg.n_contexts, cfg.left_context_embedding_dim) + self.center_encoder = get_mlp( + num_input=cfg.num_inputs + cfg.left_context_embedding_dim, + num_output=get_center_dim(cfg.n_contexts, cfg.num_hmm_states_per_phone, cfg.phoneme_state_class), + hidden_dim=cfg.context_mix_mlp_dim, + num_layers=cfg.context_mix_mlp_num_layers, + dropout=cfg.dropout, + activation=cfg.activation, + ) + + def forward( + self, + features: Tensor, # B, T, F + contexts_left: Optional[Tensor] = None, # B, T + ) -> Union[DiphoneLogitsV1, DiphoneProbsV1]: + """ + :param features: Main encoder output. shape B, T, F. F=num_inputs + :param contexts_left: The left contexts used to compute p(c|l,x). + Shape during training: B, T. Must be `None` when forwarding as the model + will compute a flat, joint output scoring all possible contexts in forwarding + mode. + :return: During training: logits for p(c|l,x) and p(l|x). + During inference: *log* probs for p(c,l|x) in one large layer. + """ + + left_logits = self.left_context_encoder(features) # B, T, C + + if self.training: + assert contexts_left is not None + assert contexts_left.ndim >= 2 + else: + assert contexts_left is None, "in eval mode, all left contexts are forwarded at the same time" + contexts_left = torch.arange(self.n_contexts) + + # train: B, T, E + # eval: C, E + embedded_left_contexts = self.left_context_embedding(contexts_left) + + if self.training: + # in training we forward exactly one context per T + + center_features = torch.cat((features, embedded_left_contexts), -1) # B, T, F+E + center_logits = self.center_encoder(center_features) # B, T, C + + return DiphoneLogitsV1( + embeddings_left_context=embedded_left_contexts, output_left=left_logits, output_center=center_logits + ) + else: + # here we forward every context to compute p(c, l|x) = p(c|l, x) * p(l|x) + + features = features.expand((self.n_contexts, -1, -1, -1)) # C, B, T, F + embedded_left_contexts_ = embedded_left_contexts.reshape((self.n_contexts, 1, 1, -1)).expand( + (-1, *features.shape[1:3], -1) + ) # C, B, T, E + center_features = torch.cat((features, embedded_left_contexts_), -1) # C, B, T, F+E + center_logits = self.center_encoder(center_features) # C, B, T, F' + center_probs = F.log_softmax(center_logits, -1) + center_probs = center_probs.permute((1, 2, 3, 0)) # B, T, F', C + left_probs = F.log_softmax(left_logits, -1) + left_probs = left_probs.unsqueeze(-2) # B, T, 1, C + + joint_probs = center_probs + left_probs # B, T, F', C + joint_probs = torch.flatten(joint_probs, start_dim=2) # B, T, joint + + return DiphoneProbsV1( + embeddings_left_context=embedded_left_contexts, + output_left=left_probs.squeeze(), + output_center=joint_probs, + ) diff --git a/i6_models/parts/factored_hybrid/util.py b/i6_models/parts/factored_hybrid/util.py new file mode 100644 index 00000000..8b7f33b3 --- /dev/null +++ b/i6_models/parts/factored_hybrid/util.py @@ -0,0 +1,63 @@ +from enum import Enum +from typing import Callable, Union + +from torch import nn + + +class PhonemeStateClassV1(Enum): + """Phoneme state class augmentation selector""" + + none = 1 + word_end = 2 + boundary = 4 + + def factor(self): + return 
+
+
+def get_center_dim(
+    n_contexts: int,
+    num_hmm_states_per_phone: int,
+    ph_class: Union[int, PhonemeStateClassV1],
+) -> int:
+    """
+    :return: number of center phonemes given the augmentation values
+    """
+
+    factor = ph_class.factor() if isinstance(ph_class, PhonemeStateClassV1) else ph_class
+    return n_contexts * num_hmm_states_per_phone * factor
+
+
+def get_mlp(
+    num_input: int,
+    num_output: int,
+    hidden_dim: int,
+    dropout: float,
+    activation: Callable[[], nn.Module],
+    num_layers: int,
+) -> nn.Module:
+    """
+    :return: a context-mixing MLP according to the specifications
+    """
+
+    assert num_input > 0
+    assert num_output > 0
+    assert num_layers > 0
+    assert hidden_dim > 0
+    assert 0.0 <= dropout <= 1.0
+
+    return nn.Sequential(
+        *[
+            layer
+            for in_dim in [
+                num_input,
+                *[hidden_dim for _ in range(num_layers - 1)],
+            ]
+            for layer in [
+                nn.Linear(in_dim, hidden_dim),
+                nn.Dropout(dropout),
+                activation(),
+            ]
+        ],
+        nn.Linear(hidden_dim, num_output, bias=True),
+    )
diff --git a/tests/test_fh.py b/tests/test_fh.py
new file mode 100644
index 00000000..735f48e9
--- /dev/null
+++ b/tests/test_fh.py
@@ -0,0 +1,68 @@
+from itertools import product
+
+import torch
+import torch.nn as nn
+
+from i6_models.parts.factored_hybrid import (
+    DiphoneBackendV1,
+    DiphoneBackendV1Config,
+    PhonemeStateClassV1,
+    DiphoneLogitsV1,
+    DiphoneProbsV1,
+)
+from i6_models.parts.factored_hybrid.util import get_center_dim
+
+
+def test_dim_calcs():
+    n_ctx = 42
+
+    assert get_center_dim(n_ctx, 1, PhonemeStateClassV1.none) == 42
+    assert get_center_dim(n_ctx, 1, PhonemeStateClassV1.word_end) == 84
+    assert get_center_dim(n_ctx, 3, PhonemeStateClassV1.word_end) == 252
+    assert get_center_dim(n_ctx, 3, PhonemeStateClassV1.boundary) == 504
+
+
+def test_output_shape():
+    n_ctx = 42
+    n_in = 32
+
+    for we_class, states_per_ph in product(
+        [PhonemeStateClassV1.none, PhonemeStateClassV1.word_end, PhonemeStateClassV1.boundary],
+        [1, 3],
+    ):
+        backend = DiphoneBackendV1(
+            DiphoneBackendV1Config(
+                activation=lambda: nn.ReLU(),
+                context_mix_mlp_dim=64,
+                context_mix_mlp_num_layers=2,
+                dropout=0.1,
+                left_context_embedding_dim=32,
+                n_contexts=n_ctx,
+                num_hmm_states_per_phone=states_per_ph,
+                num_inputs=n_in,
+                phoneme_state_class=we_class,
+            )
+        )
+
+        backend.train(True)
+        for b, t in product([10, 50, 100], [10, 50, 100]):
+            contexts_forward = torch.randint(0, n_ctx, (b, t))
+            encoder_output = torch.rand((b, t, n_in))
+            output = backend(features=encoder_output, contexts_left=contexts_forward)
+            assert isinstance(output, DiphoneLogitsV1) and not isinstance(output, DiphoneProbsV1)
+            assert output.output_left.shape == (b, t, n_ctx)
+            cdim = get_center_dim(n_ctx, states_per_ph, we_class)
+            assert output.output_center.shape == (b, t, cdim)
+
+        backend.train(False)
+        for b, t in product([10, 50, 100], [10, 50, 100]):
+            encoder_output = torch.rand((b, t, n_in))
+            output = backend(features=encoder_output)
+            assert isinstance(output, DiphoneProbsV1)
+            assert output.output_left.shape == (b, t, n_ctx)
+            cdim = get_center_dim(n_ctx, states_per_ph, we_class)
+            assert output.output_center.shape == (b, t, cdim * n_ctx)
+            output_p = torch.exp(output.output_center)
+            ones_hopefully = torch.sum(output_p, dim=-1)
+            close_to_one = torch.abs(1 - ones_hopefully).flatten() < 1e-3
+            assert all(close_to_one)
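
For reference, a minimal inference-mode usage sketch of the classes added above (not part of the diff). The dimensions, the random `encoder_output`, and the variable names are example values only; the final reshape relies on the flat joint output being laid out as (center, left) with the left context varying fastest, as produced by `torch.flatten` in `DiphoneBackendV1.forward`:

    import torch
    from torch import nn

    from i6_models.parts.factored_hybrid import DiphoneBackendV1, DiphoneBackendV1Config, PhonemeStateClassV1
    from i6_models.parts.factored_hybrid.util import get_center_dim

    cfg = DiphoneBackendV1Config(
        activation=lambda: nn.ReLU(),
        context_mix_mlp_dim=64,
        context_mix_mlp_num_layers=2,
        dropout=0.1,
        left_context_embedding_dim=32,
        n_contexts=42,
        num_hmm_states_per_phone=1,
        num_inputs=32,
        phoneme_state_class=PhonemeStateClassV1.word_end,
    )
    backend = DiphoneBackendV1(cfg)
    backend.train(False)

    encoder_output = torch.rand((8, 100, 32))  # B, T, F; stands in for the main encoder output
    with torch.no_grad():
        output = backend(features=encoder_output)  # DiphoneProbsV1

    # output.output_center holds log p(c, l | x) flattened over (center, left):
    # shape B, T, center_dim * n_contexts, with the left context as the fastest-varying index.
    center_dim = get_center_dim(cfg.n_contexts, cfg.num_hmm_states_per_phone, cfg.phoneme_state_class)
    joint_log_probs = output.output_center.reshape(8, 100, center_dim, cfg.n_contexts)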