Skip to content

Commit

Permalink
override forward_factored instead
Browse files Browse the repository at this point in the history
  • Loading branch information
NeoLegends committed Aug 23, 2024
1 parent 815934c commit 3de2552
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 5 deletions.
8 changes: 6 additions & 2 deletions i6_models/parts/factored_hybrid/triphone.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,11 @@ def __init__(self, cfg: FactoredTriphoneBlockV1Config):
activation=cfg.activation,
)

def forward(
# update type definitions
def forward(self, *args, **kwargs) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
return super().forward(*args, **kwargs)

def forward_factored(
self,
features: Tensor, # B, T, F
contexts_left: Tensor, # B, T
Expand All @@ -58,7 +62,7 @@ def forward(
:return: tuple of logits for p(c|l,x), p(l|x), p(r|c,l,x) and the embedded left context and center state values.
"""

logits_center, logits_left, contexts_left_embedded = super().forward(features, contexts_left)
logits_center, logits_left, contexts_left_embedded = super().forward_factored(features, contexts_left)

# This logic is very similar to FactoredDiphoneBlockV2.forward, but not the same.
# This class computes `p(r|c,l,h(x))` while FactoredDiphoneBlockV2 computes `p(r|c,h(x))`.
Expand Down
4 changes: 1 addition & 3 deletions tests/test_fh.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,9 +129,7 @@ def test_tri_output_shape_and_norm():
contexts_center = torch.randint(0, tri_block.num_center, (b, t))
encoder_output = torch.rand((b, t, n_in))
output_center, output_left, output_right, _, _ = tri_block(
features=encoder_output,
contexts_left=contexts_left,
contexts_center=contexts_center,
features=encoder_output, contexts_left=contexts_left, contexts_center=contexts_center
)
assert output_left.shape == (b, t, n_ctx)
assert output_center.shape == (b, t, cdim)
Expand Down

0 comments on commit 3de2552

Please sign in to comment.