In the current implementation, the `forward()` method is the same in train and eval mode. In some cases we need not only the loss but also the prediction as an output, so that extra metrics such as SDR can be computed during the validation step.
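For illustration, here is a minimal sketch of how a validation step could use such an output to compute SDR (the `return_recon` flag and the `simple_sdr` helper are hypothetical, not part of the current API):

```python
import torch

def simple_sdr(estimate: torch.Tensor, target: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    # Plain SDR in dB: 10 * log10(||target||^2 / ||target - estimate||^2),
    # reduced over the last (time) axis and averaged over batch/stems.
    noise = target - estimate
    ratio = target.pow(2).sum(dim=-1) / noise.pow(2).sum(dim=-1).clamp(min=eps)
    return (10 * torch.log10(ratio.clamp(min=eps))).mean()

# Hypothetical validation step, assuming forward() could also return the reconstruction:
# loss, recon = model(mix, target=target, return_recon=True)  # `return_recon` is an assumed flag
# val_sdr = simple_sdr(recon, target)
```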
Because the loss function code is shared by the `BSRoformer` and `MelBandRoformer` classes, it may be better to create a new class such as `MultiResLoss` for maximum flexibility:
```python
import torch
import torch.nn.functional as F
from einops import rearrange
from beartype import beartype
from beartype.typing import Tuple


class MultiResLoss():
    @beartype
    def __init__(
        self,
        num_stems,
        stft_n_fft,
        multi_stft_resolution_loss_weight=1.,
        multi_stft_resolutions_window_sizes: Tuple[int, ...] = (4096, 2048, 1024, 512, 256),
        multi_stft_hop_size=147,
        multi_stft_normalized=False
    ):
        self.num_stems = num_stems
        self.multi_stft_resolution_loss_weight = multi_stft_resolution_loss_weight
        self.multi_stft_resolutions_window_sizes = multi_stft_resolutions_window_sizes
        self.multi_stft_n_fft = stft_n_fft

        self.multi_stft_kwargs = dict(
            hop_length=multi_stft_hop_size,
            normalized=multi_stft_normalized
        )

    def __call__(
        self,
        predict,
        targets,
        return_loss_breakdown=False
    ):
        if self.num_stems > 1:
            assert targets.ndim == 4 and targets.shape[1] == self.num_stems

        if targets.ndim == 2:
            targets = rearrange(targets, '... t -> ... 1 t')

        targets = targets[..., :predict.shape[-1]]  # protect against lost length on istft

        # time-domain L1 loss
        loss = F.l1_loss(predict, targets)

        # multi-resolution STFT loss, accumulated over the configured window sizes
        multi_stft_resolution_loss = 0.
        for window_size in self.multi_stft_resolutions_window_sizes:
            res_stft_kwargs = dict(
                n_fft=max(window_size, self.multi_stft_n_fft),  # not sure what n_fft is across multi resolution stft
                win_length=window_size,
                return_complex=True,
                **self.multi_stft_kwargs,
            )

            predict_Y = torch.stft(rearrange(predict, '... s t -> (... s) t'), **res_stft_kwargs)
            targets_Y = torch.stft(rearrange(targets, '... s t -> (... s) t'), **res_stft_kwargs)

            multi_stft_resolution_loss = multi_stft_resolution_loss + F.l1_loss(predict_Y, targets_Y)

        weighted_multi_resolution_loss = multi_stft_resolution_loss * self.multi_stft_resolution_loss_weight
        total_loss = loss + weighted_multi_resolution_loss

        if not return_loss_breakdown:
            return total_loss

        return total_loss, (loss, multi_stft_resolution_loss)
```
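As a quick sanity check, the class above can be exercised on its own with random tensors (the shapes here are illustrative only):

```python
loss_fn = MultiResLoss(num_stems=4, stft_n_fft=2048)

predict = torch.randn(2, 4, 44100)  # (batch, stems, time)
targets = torch.randn(2, 4, 44100)

total_loss, (time_l1, multi_res) = loss_fn(predict, targets, return_loss_breakdown=True)
print(total_loss.item(), time_l1.item(), multi_res.item())
```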
In the same spirit, a small refactoring could be to move these common classes into a new file:
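A minimal sketch of the idea, assuming a new shared module (the file name `common.py` and the exact set of shared classes are placeholders, not part of the current repo layout):

```python
# bs_roformer/common.py (hypothetical file) would hold MultiResLoss plus the
# building blocks shared between the two model files.

# then, in both model files:
from bs_roformer.common import MultiResLoss  # hypothetical import path
```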
That could make future changes to the code easier?