diff --git a/i6_models/parts/factored_hybrid/triphone.py b/i6_models/parts/factored_hybrid/triphone.py index 3070ef17..4b26b403 100644 --- a/i6_models/parts/factored_hybrid/triphone.py +++ b/i6_models/parts/factored_hybrid/triphone.py @@ -42,7 +42,11 @@ def __init__(self, cfg: FactoredTriphoneBlockV1Config): activation=cfg.activation, ) - def forward( + # update type definitions + def forward(self, *args, **kwargs) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: + return super().forward(*args, **kwargs) + + def forward_factored( self, features: Tensor, # B, T, F contexts_left: Tensor, # B, T @@ -58,7 +62,7 @@ def forward( :return: tuple of logits for p(c|l,x), p(l|x), p(r|c,l,x) and the embedded left context and center state values. """ - logits_center, logits_left, contexts_left_embedded = super().forward(features, contexts_left) + logits_center, logits_left, contexts_left_embedded = super().forward_factored(features, contexts_left) # This logic is very similar to FactoredDiphoneBlockV2.forward, but not the same. # This class computes `p(r|c,l,h(x))` while FactoredDiphoneBlockV2 computes `p(r|c,h(x))`. diff --git a/tests/test_fh.py b/tests/test_fh.py index c42954b7..39c743f7 100644 --- a/tests/test_fh.py +++ b/tests/test_fh.py @@ -129,9 +129,7 @@ def test_tri_output_shape_and_norm(): contexts_center = torch.randint(0, tri_block.num_center, (b, t)) encoder_output = torch.rand((b, t, n_in)) output_center, output_left, output_right, _, _ = tri_block( - features=encoder_output, - contexts_left=contexts_left, - contexts_center=contexts_center, + features=encoder_output, contexts_left=contexts_left, contexts_center=contexts_center ) assert output_left.shape == (b, t, n_ctx) assert output_center.shape == (b, t, cdim)