diff --git a/tests/test_conformer.py b/tests/test_conformer.py index 0bd25424..29693c82 100644 --- a/tests/test_conformer.py +++ b/tests/test_conformer.py @@ -37,7 +37,7 @@ def get_output_shape(batch, time, features, norm=None, kernel_size=31, dropout=0 def test_ConformerPositionwiseFeedForwardV1(): def get_output_shape(input_shape, input_dim, hidden_dim, dropout, activation): x = torch.randn(input_shape) - cfg = ConformerPositionwiseFeedForwardV2Config( + cfg = ConformerPositionwiseFeedForwardV1Config( input_dim=input_dim, hidden_dim=hidden_dim, dropout=dropout, activation=activation ) conf_ffn_part = ConformerPositionwiseFeedForwardV1(cfg)