diff --git a/i6_models/parts/factored_hybrid/__init__.py b/i6_models/parts/factored_hybrid/__init__.py index 3d9da599..7184ae66 100644 --- a/i6_models/parts/factored_hybrid/__init__.py +++ b/i6_models/parts/factored_hybrid/__init__.py @@ -1,4 +1,10 @@ -__all__ = ["FactoredDiphoneBlockV1Config", "FactoredDiphoneBlockV1", "BoundaryClassV1"] +__all__ = [ + "FactoredDiphoneBlockV1Config", + "FactoredDiphoneBlockV1", + "FactoredDiphoneBlockV2Config", + "FactoredDiphoneBlockV2", + "BoundaryClassV1", +] from .diphone import * from .util import BoundaryClassV1 diff --git a/i6_models/parts/factored_hybrid/diphone.py b/i6_models/parts/factored_hybrid/diphone.py index e21ae423..9b5e69cb 100644 --- a/i6_models/parts/factored_hybrid/diphone.py +++ b/i6_models/parts/factored_hybrid/diphone.py @@ -1,6 +1,8 @@ __all__ = [ "FactoredDiphoneBlockV1Config", "FactoredDiphoneBlockV1", + "FactoredDiphoneBlockV2Config", + "FactoredDiphoneBlockV2", ] from dataclasses import dataclass @@ -148,3 +150,64 @@ def forward_joint(self, features: Tensor) -> Tensor: ) # B, T, F'*C return joint_log_probs + + +@dataclass +class FactoredDiphoneBlockV2Config(FactoredDiphoneBlockV1Config): + """ + Attributes: + Same attributes as parent class. In addition: + + center_state_embedding_dim: embedding dimension of the center state + values. Good choice is in the order of num_center_states. + """ + + center_state_embedding_dim: int + + def __post_init__(self): + super().__post_init__() + + assert self.center_state_embedding_dim > 0 + + +class FactoredDiphoneBlockV2(FactoredDiphoneBlockV1): + """ + Like FactoredDiphoneBlockV1, but computes an additional diphone output on the right context `p(r|c,x)`. + + This additional output is ignored when computing the joint, and only used in training. + """ + + def __init__(self, cfg: FactoredDiphoneBlockV2Config): + super().__init__(cfg) + + self.center_state_embedding = nn.Embedding(cfg.num_contexts, cfg.center_state_embedding_dim) + self.right_context_encoder = get_mlp( + num_input=cfg.num_inputs + cfg.center_state_embedding_dim, + num_output=self.num_contexts, + hidden_dim=cfg.context_mix_mlp_dim, + num_layers=cfg.context_mix_mlp_num_layers, + dropout=cfg.dropout, + activation=cfg.activation, + ) + + def forward_factored( + self, + features: Tensor, # B, T, F + contexts_left: Tensor, # B, T + contexts_center: Tensor, # B, T + ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: + """ + :param features: Main encoder output. shape B, T, F. F=num_inputs + :param contexts_left: The left contexts used to compute p(c|l,x), shape B, T. + :param contexts_center: The center states used to compute p(r|c,x), shape B, T. + :return: tuple of logits for p(c|l,x), p(l|x), p(r|c,x) and the embedded left context and center state values. + """ + + logits_center, logits_left, contexts_embedded_left = super().forward_factored(features, contexts_left) + + # in training we forward exactly one context per T, so: B, T, E + center_states_embedded = self.center_state_embedding(contexts_center) + features_right = torch.cat((features, center_states_embedded), -1) # B, T, F+E + logits_right = self.right_context_encoder(features_right) + + return logits_center, logits_left, logits_right, contexts_embedded_left, center_states_embedded diff --git a/tests/test_fh.py b/tests/test_fh.py index 224362de..d812c690 100644 --- a/tests/test_fh.py +++ b/tests/test_fh.py @@ -3,7 +3,13 @@ import torch import torch.nn as nn -from i6_models.parts.factored_hybrid import BoundaryClassV1, FactoredDiphoneBlockV1, FactoredDiphoneBlockV1Config +from i6_models.parts.factored_hybrid import ( + BoundaryClassV1, + FactoredDiphoneBlockV1, + FactoredDiphoneBlockV1Config, + FactoredDiphoneBlockV2, + FactoredDiphoneBlockV2Config, +) from i6_models.parts.factored_hybrid.util import get_center_dim @@ -16,7 +22,7 @@ def test_dim_calcs(): assert get_center_dim(n_ctx, 3, BoundaryClassV1.boundary) == 504 -def test_output_shape_and_norm(): +def test_v1_output_shape_and_norm(): n_ctx = 42 n_in = 32 @@ -54,3 +60,36 @@ def test_output_shape_and_norm(): ones_hopefully = torch.sum(output_p, dim=-1) close_to_one = torch.abs(1 - ones_hopefully).flatten() < 1e-3 assert all(close_to_one) + + +def test_v2_output_shape_and_norm(): + n_ctx = 42 + n_in = 32 + + for we_class, states_per_ph in product( + [BoundaryClassV1.none, BoundaryClassV1.word_end, BoundaryClassV1.boundary], + [1, 3], + ): + block = FactoredDiphoneBlockV2( + FactoredDiphoneBlockV2Config( + activation=nn.ReLU, + context_mix_mlp_dim=64, + context_mix_mlp_num_layers=2, + dropout=0.1, + left_context_embedding_dim=32, + left_context_embedding_dim=128, + num_contexts=n_ctx, + num_hmm_states_per_phone=states_per_ph, + num_inputs=n_in, + boundary_class=we_class, + ) + ) + + for b, t in product([10, 50, 100], [10, 50, 100]): + contexts_forward = torch.randint(0, n_ctx, (b, t)) + encoder_output = torch.rand((b, t, n_in)) + output_center, output_left, output_right, _, _ = block(features=encoder_output, contexts_left=contexts_forward) + assert output_left.shape == (b, t, n_ctx) + assert output_right.shape == (b, t, n_ctx) + cdim = get_center_dim(n_ctx, states_per_ph, we_class) + assert output_center.shape == (b, t, cdim)