Skip to content

Commit

Permalink
RF specaugment: allow disabling the time/freq mask
Browse files (browse the repository at this point in the history)
  • Loading branch information
albertz committed Feb 12, 2025
1 parent 8640885 commit 51918eb
Showing 1 changed file with 20 additions and 18 deletions.
38 changes: 20 additions & 18 deletions returnn/frontend/audio/specaugment.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,25 +44,27 @@ def _mask_branch():
x_masked = x
spatial_len = spatial_dim.get_dim_value_tensor()
# time mask
x_masked = random_mask(
x_masked,
mask_axis=spatial_dim,
broadcast_axis=feature_dim,
min_num=rf.minimum(step1 + step2, spatial_len),
max_num=rf.minimum(
rf.maximum(spatial_len // num_spatial_mask_factor, 2) * (step0 + step1 + step2 * 2), spatial_len
),
max_dims=max_consecutive_spatial_dims,
)
if max_consecutive_spatial_dims > 0 and num_spatial_mask_factor > 0:
x_masked = random_mask(
x_masked,
mask_axis=spatial_dim,
broadcast_axis=feature_dim,
min_num=rf.minimum(step1 + step2, spatial_len),
max_num=rf.minimum(
rf.maximum(spatial_len // num_spatial_mask_factor, 2) * (step0 + step1 + step2 * 2), spatial_len
),
max_dims=max_consecutive_spatial_dims,
)
# feature mask
x_masked = random_mask(
x_masked,
mask_axis=feature_dim,
broadcast_axis=spatial_dim,
min_num=step1 + step2,
max_num=step0 * 2 + step1 + step2 * 2,
max_dims=max_consecutive_feature_dims,
)
if max_consecutive_feature_dims > 0:
x_masked = random_mask(
x_masked,
mask_axis=feature_dim,
broadcast_axis=spatial_dim,
min_num=step1 + step2,
max_num=step0 * 2 + step1 + step2 * 2,
max_dims=max_consecutive_feature_dims,
)
return x_masked

return rf.cond(rf.get_run_ctx().train_flag | (not only_on_train), _mask_branch, lambda: x)
Expand Down

0 comments on commit 51918eb

Please sign in to comment.