diff --git a/i6_models/parts/frontend/generic_frontend.py b/i6_models/parts/frontend/generic_frontend.py index 1e8c6a47..6757b50c 100644 --- a/i6_models/parts/frontend/generic_frontend.py +++ b/i6_models/parts/frontend/generic_frontend.py @@ -85,8 +85,8 @@ def check_valid(self): assert len(self.layer_ordering) == num_convs + num_pools + num_activations, "Number of total layers mismatch!" - for kernel_sizes in filter(None, [self.conv_kernel_sizes, self.pool_kernel_sizes]): - for kernel_size in kernel_sizes: + if self.conv_kernel_sizes is not None: + for kernel_size in self.conv_kernel_sizes: assert all(k % 2 for k in kernel_size), "ConformerVGGFrontendV1 only supports odd kernel sizes" def __post__init__(self): @@ -132,7 +132,7 @@ def __init__(self, model_cfg: GenericFrontendV1Config): if layer_type == FrontendLayerType.Conv2d: conv_out_dim = model_cfg.conv_out_dims[conv_layer_index] conv_kernel_size = model_cfg.conv_kernel_sizes[conv_layer_index] - conv_stride = 1 if model_cfg.conv_strides is None else model_cfg.conv_strides[conv_layer_index] + conv_stride = (1, 1) if model_cfg.conv_strides is None else model_cfg.conv_strides[conv_layer_index] conv_padding = ( get_same_padding(conv_kernel_size) if model_cfg.conv_paddings is None @@ -177,7 +177,7 @@ def __init__(self, model_cfg: GenericFrontendV1Config): last_feat_dim = calculate_output_dim( in_dim=last_feat_dim, filter_size=pool_kernel_size[1], - stride=pool_stride[1] or pool_kernel_size[1], + stride=(pool_stride or pool_kernel_size)[1], padding=pool_padding[1], ) pool_layer_index += 1