Feature request: decouple the loss function from the forward function #22

dorpxam opened this issue Oct 26, 2023 · 2 comments

dorpxam commented Oct 26, 2023

In the current implementation, the forward() method is shared between train and eval mode. In some cases we need not only the loss but also the prediction as an output, which would allow computing extra metrics such as SDR during the validation step.

Because the loss function code is common to the BSRoformer and MelBandRoformer classes, it might be better to create a new class such as MultiResLoss, for maximum flexibility:

import torch
import torch.nn.functional as F
from einops import rearrange
from beartype import beartype
from beartype.typing import Tuple

class MultiResLoss:
    @beartype
    def __init__(
        self,
        num_stems,
        stft_n_fft,
        multi_stft_resolution_loss_weight = 1.,
        multi_stft_resolutions_window_sizes: Tuple[int, ...] = (4096, 2048, 1024, 512, 256),
        multi_stft_hop_size = 147,
        multi_stft_normalized = False
    ):
        self.num_stems = num_stems

        self.multi_stft_resolution_loss_weight = multi_stft_resolution_loss_weight
        self.multi_stft_resolutions_window_sizes = multi_stft_resolutions_window_sizes
        self.multi_stft_n_fft = stft_n_fft

        self.multi_stft_kwargs = dict(
            hop_length = multi_stft_hop_size,
            normalized = multi_stft_normalized
        )
    
    def __call__(
        self, 
        predict, 
        targets, 
        return_loss_breakdown = False
    ):
        if self.num_stems > 1:
            assert targets.ndim == 4 and targets.shape[1] == self.num_stems
        
        if targets.ndim == 2:
            targets = rearrange(targets, '... t -> ... 1 t')

        targets = targets[..., :predict.shape[-1]] # protect against lost length on istft

        # time-domain L1 loss on the raw waveforms
        loss = F.l1_loss(predict, targets)

        multi_stft_resolution_loss = 0.

        # accumulate an L1 loss over complex STFTs at several window sizes
        for window_size in self.multi_stft_resolutions_window_sizes:

            res_stft_kwargs = dict(
                n_fft = max(window_size, self.multi_stft_n_fft),  # not sure what n_fft is across multi resolution stft
                win_length = window_size,
                return_complex = True,
                **self.multi_stft_kwargs,
            )

            predict_Y = torch.stft(rearrange(predict, '... s t -> (... s) t'), **res_stft_kwargs)
            targets_Y = torch.stft(rearrange(targets, '... s t -> (... s) t'), **res_stft_kwargs)

            multi_stft_resolution_loss = multi_stft_resolution_loss + F.l1_loss(predict_Y, targets_Y)

        weighted_multi_resolution_loss = multi_stft_resolution_loss * self.multi_stft_resolution_loss_weight

        total_loss = loss + weighted_multi_resolution_loss

        if not return_loss_breakdown:
            return total_loss

        return total_loss, (loss, multi_stft_resolution_loss)
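
With such a class, a validation step could call forward() for the predictions only and compute the loss and any extra metrics afterwards. A rough usage sketch (assuming forward() is changed to return the predicted stems; model, mixture, targets and the hyperparameter values below are placeholders, and the SDR here is a simplified plain SDR just for illustration):

loss_fn = MultiResLoss(
    num_stems = 2,
    stft_n_fft = 2048
)

model.eval()
with torch.no_grad():
    predict = model(mixture)                 # predictions only, no loss inside forward
    total_loss = loss_fn(predict, targets)   # loss computed separately when needed

    # extra metrics become easy, e.g. a plain SDR per stem, lengths aligned as in the loss
    targets_trim = targets[..., :predict.shape[-1]]
    sdr = 10 * torch.log10(
        targets_trim.pow(2).sum(dim = -1) /
        ((targets_trim - predict).pow(2).sum(dim = -1) + 1e-8)
    )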

In the same spirit, a little refactoring could move the common classes into a new file:

- RMSNorm
- FeedForward
- Attention
- Transformer
- BandSplit
- MLP
- MaskEstimator

Wouldn't that make future changes to the code easier?
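
For example, both model files could then import these building blocks from a single shared module; a hypothetical layout (the file name common.py is only a suggestion):

# in bs_roformer.py / mel_band_roformer.py
from common import (
    RMSNorm,
    FeedForward,
    Attention,
    Transformer,
    BandSplit,
    MLP,
    MaskEstimator,
)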

turian commented Jun 10, 2024

I agree that separating the loss would be useful, because sometimes you just want to apply forward to get the output.

turian commented Jun 10, 2024

And you may also have custom losses.
