-
Notifications
You must be signed in to change notification settings - Fork 170
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
L1 loss added to the models #85
base: main
Are you sure you want to change the base?
Conversation
the values are clamped before sqrt to avoid NaN during training, as done in the original implementation https://github.com/nicola-decao/s-vae-pytorch/blob/master/hyperspherical_vae/distributions/von_mises_fisher.py#L68
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi @soumickmj,
Thank you very much for this contribution and sorry for the late reply. Please see some comments that, in my opinion, should be addressed before merging.
- First, we need to allow to pass
l1
as argument of the VAEConfig (see comment onvae_config.py
) - Second, since the l1 loss does not really refer to a distribution (e.g. BCE = Bernoulli, MSE = multivariate standard), we should not allow the computation of the NLL when such a loss is chosen. Instead, I propose to add an exception is
get_nll
method is called with l1 as loss as follows
raise NotImplementedError("Computation of the likelihood is not implemented when `L1 loss` is chosen")
I any case, do not hesitate if you have any questions or do not agree with the proposed modifications.
Best,
Clément
@@ -286,6 +294,18 @@ def get_nll(self, data, n_samples=1, batch_size=100): | |||
reduction="none", | |||
).sum(dim=-1) | |||
|
|||
elif self.model_config.reconstruction_loss == "l1": |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The computation of the likelihood cannot be handled with the l1 loss since it does not correspond to a tractable distribution per say.
@@ -572,6 +586,18 @@ def get_nll(self, data, n_samples=1, batch_size=100): | |||
reduction="none", | |||
).sum(dim=-1) | |||
|
|||
elif self.model_config.reconstruction_loss == "l1": |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The computation of the likelihood cannot be handled with the l1 loss since it does not correspond to a tractable distribution per say.
src/pythae/models/vae/vae_config.py
Outdated
@@ -11,7 +11,7 @@ class VAEConfig(BaseAEConfig): | |||
Parameters: | |||
input_dim (tuple): The input_data dimension. | |||
latent_dim (int): The latent space dimension. Default: None. | |||
reconstruction_loss (str): The reconstruction loss to use ['bce', 'mse']. Default: 'mse' | |||
reconstruction_loss (str): The reconstruction loss to use ['bce', 'l1', 'mse']. Default: 'mse' | |||
""" | |||
|
|||
reconstruction_loss: Literal["bce", "mse"] = "mse" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should be replaced by
reconstruction_loss: Literal["bce", "mse", "l1"] = "mse"
@@ -424,6 +424,20 @@ def _log_p_x_given_z(self, recon_x, x): | |||
reduction="none", | |||
).sum(dim=-1) | |||
|
|||
elif self.model_config.reconstruction_loss == "l1": |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For this model, this should correspond to a known distribution. The mse
actually models a multivariate normal. In the case of L1 loss, there is not really an underlying distribution so we should not allow the L1 for this model.
@@ -9,6 +9,7 @@ class RHVAEConfig(VAEConfig): | |||
|
|||
Parameters: | |||
latent_dim (int): The latent dimension used for the latent space. Default: 10 | |||
reconstruction_loss (str): The reconstruction loss to use ['bce', 'l1', 'mse']. Default: 'mse' |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
l1 cannot be allowed in this model (see next comment)
@@ -186,6 +194,18 @@ def get_nll(self, data, n_samples=1, batch_size=100): | |||
reduction="none", | |||
).sum(dim=-1) | |||
|
|||
elif self.model_config.reconstruction_loss == "l1": |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's not compute the likelihood when l1 loss is chosen since there is no associated distribution. We may add a warning with the following message
raise NotImplementedError("Computation of the likelihood is not implemented when `L1 loss` is chosen")
@@ -212,6 +220,18 @@ def get_nll(self, data, n_samples=1, batch_size=100): | |||
reduction="none", | |||
).sum(dim=-1) | |||
|
|||
elif self.model_config.reconstruction_loss == "l1": |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same as before for the computation of the NLL.
@@ -226,6 +234,18 @@ def get_nll(self, data, n_samples=1, batch_size=100): | |||
reduction="none", | |||
).sum(dim=-1) | |||
|
|||
elif self.model_config.reconstruction_loss == "l1": |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same as before for the computation of the NLL.
@@ -243,6 +251,18 @@ def get_nll(self, data, n_samples=1, batch_size=100): | |||
reduction="none", | |||
).sum(dim=-1) | |||
|
|||
elif self.model_config.reconstruction_loss == "l1": |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same as before for the computation of the NLL.
@@ -38,6 +38,9 @@ def model_configs_no_input_dim(request): | |||
RHVAEConfig( | |||
input_dim=(1, 28, 28), latent_dim=1, n_lf=1, reconstruction_loss="bce" | |||
), | |||
RHVAEConfig( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not needed since this model will not handled L1 loss
Custom loss plan cancelled
MSE loss sometimes make the models produce smooth images. L1 loss is an easy drop-in fix for the same. I have added to the models where we already had MSE and BCE loss functions, and skipped (for now) the ones without the recon loss flag. In the futurue, I will also add a feature to pass custom loss functions.