Skip to content

Commit

Permalink
Add FactoredDiphoneBlockV2 with right context output for training
Browse files Browse the repository at this point in the history
  • Loading branch information
NeoLegends committed Aug 20, 2024
1 parent 31c284d commit bff53f4
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 3 deletions.
8 changes: 7 additions & 1 deletion i6_models/parts/factored_hybrid/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
__all__ = ["FactoredDiphoneBlockV1Config", "FactoredDiphoneBlockV1", "BoundaryClassV1"]
__all__ = [
"FactoredDiphoneBlockV1Config",
"FactoredDiphoneBlockV1",
"FactoredDiphoneBlockV2Config",
"FactoredDiphoneBlockV2",
"BoundaryClassV1",
]

from .diphone import *
from .util import BoundaryClassV1
63 changes: 63 additions & 0 deletions i6_models/parts/factored_hybrid/diphone.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
__all__ = [
"FactoredDiphoneBlockV1Config",
"FactoredDiphoneBlockV1",
"FactoredDiphoneBlockV2Config",
"FactoredDiphoneBlockV2",
]

from dataclasses import dataclass
Expand Down Expand Up @@ -148,3 +150,64 @@ def forward_joint(self, features: Tensor) -> Tensor:
) # B, T, F'*C

return joint_log_probs


@dataclass
class FactoredDiphoneBlockV2Config(FactoredDiphoneBlockV1Config):
"""
Attributes:
Same attributes as parent class. In addition:
center_state_embedding_dim: embedding dimension of the center state
values. Good choice is in the order of num_center_states.
"""

center_state_embedding_dim: int

def __post_init__(self):
super().__post_init__()

assert self.center_state_embedding_dim > 0


class FactoredDiphoneBlockV2(FactoredDiphoneBlockV1):
"""
Like FactoredDiphoneBlockV1, but computes an additional diphone output on the right context `p(r|c,x)`.
This additional output is ignored when computing the joint, and only used in training.
"""

def __init__(self, cfg: FactoredDiphoneBlockV2Config):
super().__init__(cfg)

self.center_state_embedding = nn.Embedding(cfg.num_contexts, cfg.center_state_embedding_dim)
self.right_context_encoder = get_mlp(
num_input=cfg.num_inputs + cfg.center_state_embedding_dim,
num_output=self.num_contexts,
hidden_dim=cfg.context_mix_mlp_dim,
num_layers=cfg.context_mix_mlp_num_layers,
dropout=cfg.dropout,
activation=cfg.activation,
)

def forward_factored(
self,
features: Tensor, # B, T, F
contexts_left: Tensor, # B, T
contexts_center: Tensor, # B, T
) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
"""
:param features: Main encoder output. shape B, T, F. F=num_inputs
:param contexts_left: The left contexts used to compute p(c|l,x), shape B, T.
:param contexts_center: The center states used to compute p(r|c,x), shape B, T.
:return: tuple of logits for p(c|l,x), p(l|x), p(r|c,x) and the embedded left context and center state values.
"""

logits_center, logits_left, contexts_embedded_left = super().forward_factored(features, contexts_left)

# in training we forward exactly one context per T, so: B, T, E
center_states_embedded = self.center_state_embedding(contexts_center)
features_right = torch.cat((features, center_states_embedded), -1) # B, T, F+E
logits_right = self.right_context_encoder(features_right)

return logits_center, logits_left, logits_right, contexts_embedded_left, center_states_embedded
43 changes: 41 additions & 2 deletions tests/test_fh.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,13 @@
import torch
import torch.nn as nn

from i6_models.parts.factored_hybrid import BoundaryClassV1, FactoredDiphoneBlockV1, FactoredDiphoneBlockV1Config
from i6_models.parts.factored_hybrid import (
BoundaryClassV1,
FactoredDiphoneBlockV1,
FactoredDiphoneBlockV1Config,
FactoredDiphoneBlockV2,
FactoredDiphoneBlockV2Config,
)
from i6_models.parts.factored_hybrid.util import get_center_dim


Expand All @@ -16,7 +22,7 @@ def test_dim_calcs():
assert get_center_dim(n_ctx, 3, BoundaryClassV1.boundary) == 504


def test_output_shape_and_norm():
def test_v1_output_shape_and_norm():
n_ctx = 42
n_in = 32

Expand Down Expand Up @@ -54,3 +60,36 @@ def test_output_shape_and_norm():
ones_hopefully = torch.sum(output_p, dim=-1)
close_to_one = torch.abs(1 - ones_hopefully).flatten() < 1e-3
assert all(close_to_one)


def test_v2_output_shape_and_norm():
n_ctx = 42
n_in = 32

for we_class, states_per_ph in product(
[BoundaryClassV1.none, BoundaryClassV1.word_end, BoundaryClassV1.boundary],
[1, 3],
):
block = FactoredDiphoneBlockV2(
FactoredDiphoneBlockV2Config(
activation=nn.ReLU,
context_mix_mlp_dim=64,
context_mix_mlp_num_layers=2,
dropout=0.1,
left_context_embedding_dim=32,
left_context_embedding_dim=128,
num_contexts=n_ctx,
num_hmm_states_per_phone=states_per_ph,
num_inputs=n_in,
boundary_class=we_class,
)
)

for b, t in product([10, 50, 100], [10, 50, 100]):
contexts_forward = torch.randint(0, n_ctx, (b, t))
encoder_output = torch.rand((b, t, n_in))
output_center, output_left, output_right, _, _ = block(features=encoder_output, contexts_left=contexts_forward)
assert output_left.shape == (b, t, n_ctx)
assert output_right.shape == (b, t, n_ctx)
cdim = get_center_dim(n_ctx, states_per_ph, we_class)
assert output_center.shape == (b, t, cdim)

0 comments on commit bff53f4

Please sign in to comment.