Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

L1 loss added to the models #85

Open
wants to merge 22 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/pythae/models/adversarial_ae/adversarial_ae_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ class Adversarial_AE_Config(VAEConfig):
Parameters:
input_dim (tuple): The input_data dimension.
latent_dim (int): The latent space dimension. Default: None.
reconstruction_loss (str): The reconstruction loss to use ['bce', 'mse']. Default: 'mse'
reconstruction_loss (str): The reconstruction loss to use ['bce', 'l1', 'mse']. Default: 'mse'
adversarial_loss_scale (float): Parameter scaling the adversarial loss. Default: 0.5
reconstruction_loss_scale (float): Parameter scaling the reconstruction loss. Default: 1
deterministic_posterior (bool): Whether to use a deterministic posterior (Dirac). Default:
Expand Down
8 changes: 8 additions & 0 deletions src/pythae/models/adversarial_ae/adversarial_ae_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,14 @@ def loss_function(self, recon_x, x, z, z_prior):
reduction="none",
).sum(dim=-1)

elif self.model_config.reconstruction_loss == "l1":

recon_loss = F.l1_loss(
recon_x.reshape(x.shape[0], -1),
x.reshape(x.shape[0], -1),
reduction="none",
).sum(dim=-1)

gen_adversarial_score = self.discriminator(z).embedding.flatten()
prior_adversarial_score = self.discriminator(z_prior).embedding.flatten()

Expand Down
2 changes: 1 addition & 1 deletion src/pythae/models/beta_tc_vae/beta_tc_vae_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ class BetaTCVAEConfig(VAEConfig):
Parameters:
input_dim (tuple): The input_data dimension.
latent_dim (int): The latent space dimension. Default: None.
reconstruction_loss (str): The reconstruction loss to use ['bce', 'mse']. Default: 'mse'
reconstruction_loss (str): The reconstruction loss to use ['bce', 'l1', 'mse']. Default: 'mse'
alpha (float): The balancing factor before the Index code Mutual Info. Default: 1
beta (float): The balancing factor before the Total Correlation. Default: 1
gamma (float): The balancing factor before the dimension-wise KL. Default: 1
Expand Down
8 changes: 8 additions & 0 deletions src/pythae/models/beta_tc_vae/beta_tc_vae_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,14 @@ def loss_function(self, recon_x, x, mu, log_var, z, dataset_size):
reduction="none",
).sum(dim=-1)

elif self.model_config.reconstruction_loss == "l1":

recon_loss = F.l1_loss(
recon_x.reshape(x.shape[0], -1),
x.reshape(x.shape[0], -1),
reduction="none",
).sum(dim=-1)

log_q_z_given_x = self._compute_log_gauss_density(z, mu, log_var).sum(
dim=-1
) # [B]
Expand Down
2 changes: 1 addition & 1 deletion src/pythae/models/beta_vae/beta_vae_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ class BetaVAEConfig(VAEConfig):
Parameters:
input_dim (tuple): The input_data dimension.
latent_dim (int): The latent space dimension. Default: None.
reconstruction_loss (str): The reconstruction loss to use ['bce', 'mse']. Default: 'mse'
reconstruction_loss (str): The reconstruction loss to use ['bce', 'l1', 'mse']. Default: 'mse'
beta (float): The balancing factor. Default: 1
"""

Expand Down
8 changes: 8 additions & 0 deletions src/pythae/models/beta_vae/beta_vae_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,14 @@ def loss_function(self, recon_x, x, mu, log_var, z):
reduction="none",
).sum(dim=-1)

elif self.model_config.reconstruction_loss == "l1":

recon_loss = F.l1_loss(
recon_x.reshape(x.shape[0], -1),
x.reshape(x.shape[0], -1),
reduction="none",
).sum(dim=-1)

KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp(), dim=-1)

return (
Expand Down
2 changes: 1 addition & 1 deletion src/pythae/models/ciwae/ciwae_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ class CIWAEConfig(VAEConfig):
Parameters:
input_dim (tuple): The input_data dimension.
latent_dim (int): The latent space dimension. Default: None.
reconstruction_loss (str): The reconstruction loss to use ['bce', 'mse']. Default: 'mse'
reconstruction_loss (str): The reconstruction loss to use ['bce', 'l1', 'mse']. Default: 'mse'
number_samples (int): Number of samples to use on the Monte-Carlo estimation
beta (float): The value of the factor in the convex combination of the VAE and IWAE ELBO.
Default: 0.5.
Expand Down
10 changes: 10 additions & 0 deletions src/pythae/models/ciwae/ciwae_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,16 @@ def loss_function(self, recon_x, x, mu, log_var, z):
reduction="none",
).sum(dim=-1)

elif self.model_config.reconstruction_loss == "l1":

recon_loss = F.l1_loss(
recon_x,
x.reshape(recon_x.shape[0], -1)
.unsqueeze(1)
.repeat(1, self.n_samples, 1),
reduction="none",
).sum(dim=-1)

log_q_z = (-0.5 * (log_var + torch.pow(z - mu, 2) / log_var.exp())).sum(dim=-1)
log_p_z = -0.5 * (z ** 2).sum(dim=-1)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ class DisentangledBetaVAEConfig(VAEConfig):
Parameters:
input_dim (tuple): The input_data dimension.
latent_dim (int): The latent space dimension. Default: None.
reconstruction_loss (str): The reconstruction loss to use ['bce', 'mse']. Default: 'mse'
reconstruction_loss (str): The reconstruction loss to use ['bce', 'l1', 'mse']. Default: 'mse'
beta (float): The balancing factor. Default: 10.
C (float): The value of the KL divergence term of the ELBO we wish to approach, measured in
nats. Default: 50.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,14 @@ def loss_function(self, recon_x, x, mu, log_var, z, epoch):
reduction="none",
).sum(dim=-1)

elif self.model_config.reconstruction_loss == "l1":

recon_loss = F.l1_loss(
recon_x.reshape(x.shape[0], -1),
x.reshape(x.shape[0], -1),
reduction="none",
).sum(dim=-1)

KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp(), dim=-1)
C_factor = min(epoch / (self.warmup_epoch + 1), 1)
KLD_diff = torch.abs(KLD - self.C * C_factor)
Expand Down
2 changes: 1 addition & 1 deletion src/pythae/models/factor_vae/factor_vae_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ class FactorVAEConfig(VAEConfig):
Parameters:
input_dim (tuple): The input_data dimension.
latent_dim (int): The latent space dimension. Default: None.
reconstruction_loss (str): The reconstruction loss to use ['bce', 'mse']. Default: 'mse'
reconstruction_loss (str): The reconstruction loss to use ['bce', 'l1', 'mse']. Default: 'mse'
gamma (float): The balancing factor before the Total Correlation. Default: 2.0
"""
gamma: float = 2.0
Expand Down
8 changes: 8 additions & 0 deletions src/pythae/models/factor_vae/factor_vae_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,14 @@ def loss_function(self, recon_x, x, mu, log_var, z, z_bis_permuted):
reduction="none",
).sum(dim=-1)

elif self.model_config.reconstruction_loss == "l1":

recon_loss = F.l1_loss(
recon_x.reshape(x.shape[0], -1),
x.reshape(x.shape[0], -1),
reduction="none",
).sum(dim=-1)

KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp(), dim=-1)

latent_adversarial_score = self.discriminator(z)
Expand Down
2 changes: 1 addition & 1 deletion src/pythae/models/info_vae/info_vae_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class INFOVAE_MMD_Config(VAEConfig):
Parameters:
input_dim (tuple): The input_data dimension.
latent_dim (int): The latent space dimension. Default: None.
reconstruction_loss (str): The reconstruction loss to use ['bce', 'mse']. Default: 'mse'
reconstruction_loss (str): The reconstruction loss to use ['bce', 'l1', 'mse']. Default: 'mse'
kernel_choice (str): The kernel to choose. Available options are ['rbf', 'imq'] i.e.
radial basis functions or inverse multiquadratic kernel. Default: 'imq'.
alpha (float): The alpha factor balancing the weight. Default: 0.5
Expand Down
8 changes: 8 additions & 0 deletions src/pythae/models/info_vae/info_vae_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,14 @@ def loss_function(self, recon_x, x, z, z_prior, mu, log_var):
reduction="none",
).sum(dim=-1)

elif self.model_config.reconstruction_loss == "l1":

recon_loss = F.l1_loss(
recon_x.reshape(x.shape[0], -1),
x.reshape(x.shape[0], -1),
reduction="none",
).sum(dim=-1)

KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp(), dim=-1)

if self.kernel_choice == "rbf":
Expand Down
2 changes: 1 addition & 1 deletion src/pythae/models/iwae/iwae_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ class IWAEConfig(VAEConfig):
Parameters:
input_dim (tuple): The input_data dimension.
latent_dim (int): The latent space dimension. Default: None.
reconstruction_loss (str): The reconstruction loss to use ['bce', 'mse']. Default: 'mse'
reconstruction_loss (str): The reconstruction loss to use ['bce', 'l1', 'mse']. Default: 'mse'
number_samples (int): Number of samples to use on the Monte-Carlo estimation. Default: 10
"""

Expand Down
10 changes: 10 additions & 0 deletions src/pythae/models/iwae/iwae_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,16 @@ def loss_function(self, recon_x, x, mu, log_var, z):
reduction="none",
).sum(dim=-1)

elif self.model_config.reconstruction_loss == "l1":

recon_loss = F.l1_loss(
recon_x,
x.reshape(recon_x.shape[0], -1)
.unsqueeze(1)
.repeat(1, self.n_samples, 1),
reduction="none",
).sum(dim=-1)

log_q_z = (-0.5 * (log_var + torch.pow(z - mu, 2) / log_var.exp())).sum(dim=-1)
log_p_z = -0.5 * (z ** 2).sum(dim=-1)

Expand Down
2 changes: 1 addition & 1 deletion src/pythae/models/miwae/miwae_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ class MIWAEConfig(VAEConfig):
Parameters:
input_dim (tuple): The input_data dimension.
latent_dim (int): The latent space dimension. Default: None.
reconstruction_loss (str): The reconstruction loss to use ['bce', 'mse']. Default: 'mse'
reconstruction_loss (str): The reconstruction loss to use ['bce', 'l1', 'mse']. Default: 'mse'
number_gradient_estimates (int): Number of (M-)estimates to use for the gradient
estimate. Default: 5
number_samples (int): Number of samples to use on the Monte-Carlo estimation. Default: 10
Expand Down
11 changes: 11 additions & 0 deletions src/pythae/models/miwae/miwae_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,17 @@ def loss_function(self, recon_x, x, mu, log_var, z):
reduction="none",
).sum(dim=-1)

elif self.model_config.reconstruction_loss == "l1":

recon_loss = F.l1_loss(
recon_x,
x.reshape(recon_x.shape[0], -1)
.unsqueeze(1)
.unsqueeze(1)
.repeat(1, self.gradient_n_estimates, self.n_samples, 1),
reduction="none",
).sum(dim=-1)

log_q_z = (-0.5 * (log_var + torch.pow(z - mu, 2) / log_var.exp())).sum(dim=-1)
log_p_z = -0.5 * (z ** 2).sum(dim=-1)

Expand Down
2 changes: 1 addition & 1 deletion src/pythae/models/piwae/piwae_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ class PIWAEConfig(VAEConfig):
Parameters:
input_dim (tuple): The input_data dimension.
latent_dim (int): The latent space dimension. Default: None.
reconstruction_loss (str): The reconstruction loss to use ['bce', 'mse']. Default: 'mse'
reconstruction_loss (str): The reconstruction loss to use ['bce', 'l1', 'mse']. Default: 'mse'
number_gradient_estimates (int): Number of (M-)estimates to use for the gradient
estimate for the encoder. Default: 5
number_samples (int): Number of samples to use on the Monte-Carlo estimation. Default: 10
Expand Down
11 changes: 11 additions & 0 deletions src/pythae/models/piwae/piwae_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,17 @@ def loss_function(self, recon_x, x, mu, log_var, z):
reduction="none",
).sum(dim=-1)

elif self.model_config.reconstruction_loss == "l1":

recon_loss = F.l1_loss(
recon_x,
x.reshape(recon_x.shape[0], -1)
.unsqueeze(1)
.unsqueeze(1)
.repeat(1, self.gradient_n_estimates, self.n_samples, 1),
reduction="none",
).sum(dim=-1)

log_q_z = (-0.5 * (log_var + torch.pow(z - mu, 2) / log_var.exp())).sum(dim=-1)
log_p_z = -0.5 * (z ** 2).sum(dim=-1)

Expand Down
2 changes: 1 addition & 1 deletion src/pythae/models/pvae/pvae_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ class PoincareVAEConfig(VAEConfig):
Parameters:
input_dim (tuple): The input_data dimension.
latent_dim (int): The latent space dimension. Default: None.
reconstruction_loss (str): The reconstruction loss to use ['bce', 'mse']. Default: 'mse'
reconstruction_loss (str): The reconstruction loss to use ['bce', 'l1', 'mse']. Default: 'mse'
prior_distribution (str): The distribution to use as prior
["wrapped_normal", "riemannian_normal"]. Default: "wrapped_normal"
posterior_distribution (str): The distribution to use as posterior
Expand Down
20 changes: 20 additions & 0 deletions src/pythae/models/pvae/pvae_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,14 @@ def loss_function(self, recon_x, x, z, qz_x):
reduction="none",
).sum(dim=-1)

elif self.model_config.reconstruction_loss == "l1":

recon_loss = F.l1_loss(
recon_x.reshape(x.shape[0], -1),
x.reshape(x.shape[0], -1),
reduction="none",
).sum(dim=-1)

pz = self.prior(
loc=self._pz_mu, scale=self._pz_logvar.exp(), manifold=self.latent_manifold
)
Expand Down Expand Up @@ -278,6 +286,18 @@ def get_nll(self, data, n_samples=1, batch_size=100):
reduction="none",
).sum(dim=-1)

elif self.model_config.reconstruction_loss == "l1":

log_p_x_given_z = -0.5 * F.l1_loss(
recon_x.reshape(x_rep.shape[0], -1),
x_rep.reshape(x_rep.shape[0], -1),
reduction="none",
).sum(dim=-1) - torch.tensor(
[np.prod(self.input_dim) / 2 * np.log(np.pi * 2)]
).to(
data.device
) # decoding distribution is assumed unit variance N(mu, I)

log_p_x.append(log_p_x_given_z + log_p_z - log_q_z_given_x)

log_p_x = torch.cat(log_p_x)
Expand Down
1 change: 1 addition & 0 deletions src/pythae/models/rhvae/rhvae_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ class RHVAEConfig(VAEConfig):

Parameters:
latent_dim (int): The latent dimension used for the latent space. Default: 10
reconstruction_loss (str): The reconstruction loss to use ['bce', 'mse']. Default: 'mse'
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

l1 cannot be allowed in this model (see next comment)

n_lf (int): The number of leapfrog steps to use in the integrator: Default: 3
eps_lf (int): The leapfrog stepsize. Default: 1e-3
beta_zero (int): The tempering factor in the Riemannian Hamiltonian Monte Carlo Sampler.
Expand Down
26 changes: 26 additions & 0 deletions src/pythae/models/rhvae/rhvae_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,20 @@ def _log_p_x_given_z(self, recon_x, x):
reduction="none",
).sum(dim=-1)

elif self.model_config.reconstruction_loss == "l1":
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For this model, this should correspond to a known distribution. The mse actually models a multivariate normal. In the case of L1 loss, there is not really an underlying distribution so we should not allow the L1 for this model.

# sigma is taken as I_D
recon_loss = (
-0.5
* F.l1_loss(
recon_x.reshape(x.shape[0], -1),
x.reshape(x.shape[0], -1),
reduction="none",
).sum(dim=-1)
)
-torch.log(torch.tensor([2 * np.pi]).to(x.device)) * np.prod(
self.input_dim
) / 2

return recon_loss

def _log_z(self, z):
Expand Down Expand Up @@ -572,6 +586,18 @@ def get_nll(self, data, n_samples=1, batch_size=100):
reduction="none",
).sum(dim=-1)

elif self.model_config.reconstruction_loss == "l1":
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The computation of the likelihood cannot be handled with the l1 loss since it does not correspond to a tractable distribution per se.


log_p_x_given_z = -0.5 * F.l1_loss(
recon_x.reshape(x_rep.shape[0], -1),
x_rep.reshape(x_rep.shape[0], -1),
reduction="none",
).sum(dim=-1) - torch.tensor(
[np.prod(self.input_dim) / 2 * np.log(np.pi * 2)]
).to(
data.device
) # decoding distribution is assumed unit variance N(mu, I)

log_p_x.append(
log_p_x_given_z
+ log_p_z
Expand Down
2 changes: 1 addition & 1 deletion src/pythae/models/svae/svae_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,5 @@ class SVAEConfig(VAEConfig):
Parameters:
input_dim (tuple): The input_data dimension.
latent_dim (int): The latent space dimension in which lives the hypersphere. Default: None.
reconstruction_loss (str): The reconstruction loss to use ['bce', 'mse']. Default: 'mse'
reconstruction_loss (str): The reconstruction loss to use ['bce', 'l1', 'mse']. Default: 'mse'
"""
20 changes: 20 additions & 0 deletions src/pythae/models/svae/svae_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,14 @@ def loss_function(self, recon_x, x, loc, concentration, z):
reduction="none",
).sum(dim=-1)

elif self.model_config.reconstruction_loss == "l1":

recon_loss = F.l1_loss(
recon_x.reshape(x.shape[0], -1),
x.reshape(x.shape[0], -1),
reduction="none",
).sum(dim=-1)

KLD = self._compute_kl(m=loc.shape[-1], concentration=concentration)

return (recon_loss + KLD).mean(dim=0), recon_loss.mean(dim=0), KLD.mean(dim=0)
Expand Down Expand Up @@ -286,6 +294,18 @@ def get_nll(self, data, n_samples=1, batch_size=100):
reduction="none",
).sum(dim=-1)

elif self.model_config.reconstruction_loss == "l1":
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The computation of the likelihood cannot be handled with the l1 loss since it does not correspond to a tractable distribution per se.


log_p_x_given_z = -0.5 * F.l1_loss(
recon_x.reshape(x_rep.shape[0], -1),
x_rep.reshape(x_rep.shape[0], -1),
reduction="none",
).sum(dim=-1) - torch.tensor(
[np.prod(self.input_dim) / 2 * np.log(np.pi * 2)]
).to(
data.device
) # decoding distribution is assumed unit variance N(mu, I)

log_p_x.append(
log_p_x_given_z + log_p_z - log_q_z_given_x
) # log(2*pi) simplifies
Expand Down
2 changes: 1 addition & 1 deletion src/pythae/models/vae/vae_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ class VAEConfig(BaseAEConfig):
Parameters:
input_dim (tuple): The input_data dimension.
latent_dim (int): The latent space dimension. Default: None.
reconstruction_loss (str): The reconstruction loss to use ['bce', 'mse']. Default: 'mse'
reconstruction_loss (str): The reconstruction loss to use ['bce', 'l1', 'mse']. Default: 'mse'
"""

reconstruction_loss: Literal["bce", "mse", "l1"] = "mse"
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be replaced by

 reconstruction_loss: Literal["bce", "mse", "l1"] = "mse"

Loading