From 905a9e2f223fb55879e93039b1a8c2c07fce9c0b Mon Sep 17 00:00:00 2001 From: Soumick Chatterjee Date: Tue, 7 Feb 2023 17:58:14 +0100 Subject: [PATCH 01/15] predict method in the base, hmogenising recon_loss --- src/pythae/models/base/base_model.py | 20 ++++++++++++++++ .../models/beta_tc_vae/beta_tc_vae_model.py | 2 +- src/pythae/models/beta_vae/beta_vae_model.py | 2 +- src/pythae/models/ciwae/ciwae_model.py | 2 +- .../disentangled_beta_vae_model.py | 2 +- src/pythae/models/info_vae/info_vae_model.py | 2 +- src/pythae/models/iwae/iwae_model.py | 2 +- src/pythae/models/miwae/miwae_model.py | 2 +- .../models/msssim_vae/msssim_vae_model.py | 2 +- src/pythae/models/piwae/piwae_model.py | 2 +- src/pythae/models/pvae/pvae_model.py | 2 +- src/pythae/models/svae/svae_model.py | 2 +- src/pythae/models/vae/vae_model.py | 2 +- src/pythae/models/vae_iaf/vae_iaf_model.py | 2 +- .../models/vae_lin_nf/vae_lin_nf_model.py | 2 +- src/pythae/models/vamp/vamp_model.py | 2 +- tests/test_AE.py | 23 +++++++++++++++++++ tests/test_BetaTCVAE.py | 2 +- tests/test_BetaVAE.py | 2 +- tests/test_CIWAE.py | 2 +- tests/test_DisentangledBetaVAE.py | 2 +- tests/test_IWAE.py | 2 +- tests/test_MIWAE.py | 2 +- tests/test_MSSSIMVAE.py | 2 +- tests/test_PIWAE.py | 2 +- tests/test_PoincareVAE.py | 2 +- tests/test_SVAE.py | 2 +- tests/test_VAE.py | 2 +- tests/test_VAE_IAF.py | 2 +- tests/test_VAE_LinFlow.py | 2 +- tests/test_VAMP.py | 2 +- tests/test_info_vae_mmd.py | 2 +- 32 files changed, 73 insertions(+), 30 deletions(-) diff --git a/src/pythae/models/base/base_model.py b/src/pythae/models/base/base_model.py index bf5d1b1e..50b52293 100644 --- a/src/pythae/models/base/base_model.py +++ b/src/pythae/models/base/base_model.py @@ -117,6 +117,26 @@ def reconstruct(self, inputs: torch.Tensor): """ return self(DatasetOutput(data=inputs)).recon_x + def predict(self, inputs: BaseDataset, **kwargs) -> ModelOutput: + """The input data is encoded and decoded without computing loss + Args: + inputs (BaseDataset): An instance of pythae's datasets + Returns: + ModelOutput: An instance of ModelOutput containing reconstruction and embedding + """ + + x = inputs["data"] + + z = self.encoder(x).embedding + recon_x = self.decoder(z)["reconstruction"] + + output = ModelOutput( + recon_x=recon_x, + embedding=z, + ) + + return output + def interpolate( self, starting_inputs: torch.Tensor, diff --git a/src/pythae/models/beta_tc_vae/beta_tc_vae_model.py b/src/pythae/models/beta_tc_vae/beta_tc_vae_model.py index 99576c28..cd8d76ac 100644 --- a/src/pythae/models/beta_tc_vae/beta_tc_vae_model.py +++ b/src/pythae/models/beta_tc_vae/beta_tc_vae_model.py @@ -79,7 +79,7 @@ def forward(self, inputs: BaseDataset, **kwargs): ) output = ModelOutput( - reconstruction_loss=recon_loss, + recon_loss=recon_loss, reg_loss=kld, loss=loss, recon_x=recon_x, diff --git a/src/pythae/models/beta_vae/beta_vae_model.py b/src/pythae/models/beta_vae/beta_vae_model.py index d05f3bd4..7e7ed480 100644 --- a/src/pythae/models/beta_vae/beta_vae_model.py +++ b/src/pythae/models/beta_vae/beta_vae_model.py @@ -71,7 +71,7 @@ def forward(self, inputs: BaseDataset, **kwargs): loss, recon_loss, kld = self.loss_function(recon_x, x, mu, log_var, z) output = ModelOutput( - reconstruction_loss=recon_loss, + recon_loss=recon_loss, reg_loss=kld, loss=loss, recon_x=recon_x, diff --git a/src/pythae/models/ciwae/ciwae_model.py b/src/pythae/models/ciwae/ciwae_model.py index efdbe936..6b8158a4 100644 --- a/src/pythae/models/ciwae/ciwae_model.py +++ b/src/pythae/models/ciwae/ciwae_model.py @@ -78,7 +78,7 @@ def forward(self, inputs: BaseDataset, **kwargs): loss, recon_loss, kld = self.loss_function(recon_x, x, mu, log_var, z) output = ModelOutput( - reconstruction_loss=recon_loss, + recon_loss=recon_loss, reg_loss=kld, loss=loss, recon_x=recon_x.reshape(x.shape[0], self.n_samples, -1)[:, 0, :].reshape_as( diff --git a/src/pythae/models/disentangled_beta_vae/disentangled_beta_vae_model.py b/src/pythae/models/disentangled_beta_vae/disentangled_beta_vae_model.py index 87bb2f00..7d935290 100644 --- a/src/pythae/models/disentangled_beta_vae/disentangled_beta_vae_model.py +++ b/src/pythae/models/disentangled_beta_vae/disentangled_beta_vae_model.py @@ -79,7 +79,7 @@ def forward(self, inputs: BaseDataset, **kwargs): loss, recon_loss, kld = self.loss_function(recon_x, x, mu, log_var, z, epoch) output = ModelOutput( - reconstruction_loss=recon_loss, + recon_loss=recon_loss, reg_loss=kld, loss=loss, recon_x=recon_x, diff --git a/src/pythae/models/info_vae/info_vae_model.py b/src/pythae/models/info_vae/info_vae_model.py index 95599b05..915b5c7a 100644 --- a/src/pythae/models/info_vae/info_vae_model.py +++ b/src/pythae/models/info_vae/info_vae_model.py @@ -77,7 +77,7 @@ def forward(self, inputs: BaseDataset, **kwargs) -> ModelOutput: output = ModelOutput( loss=loss, - reconstruction_loss=recon_loss, + recon_loss=recon_loss, reg_loss=kld_loss, mmd_loss=mmd_loss, recon_x=recon_x, diff --git a/src/pythae/models/iwae/iwae_model.py b/src/pythae/models/iwae/iwae_model.py index 6bef4bf9..7b03b981 100644 --- a/src/pythae/models/iwae/iwae_model.py +++ b/src/pythae/models/iwae/iwae_model.py @@ -77,7 +77,7 @@ def forward(self, inputs: BaseDataset, **kwargs): loss, recon_loss, kld = self.loss_function(recon_x, x, mu, log_var, z) output = ModelOutput( - reconstruction_loss=recon_loss, + recon_loss=recon_loss, reg_loss=kld, loss=loss, recon_x=recon_x.reshape(x.shape[0], self.n_samples, -1)[:, 0, :].reshape_as( diff --git a/src/pythae/models/miwae/miwae_model.py b/src/pythae/models/miwae/miwae_model.py index 9bc55280..e9549022 100644 --- a/src/pythae/models/miwae/miwae_model.py +++ b/src/pythae/models/miwae/miwae_model.py @@ -86,7 +86,7 @@ def forward(self, inputs: BaseDataset, **kwargs): loss, recon_loss, kld = self.loss_function(recon_x, x, mu, log_var, z) output = ModelOutput( - reconstruction_loss=recon_loss, + recon_loss=recon_loss, reg_loss=kld, loss=loss, recon_x=recon_x.reshape( diff --git a/src/pythae/models/msssim_vae/msssim_vae_model.py b/src/pythae/models/msssim_vae/msssim_vae_model.py index 686fcb8a..b942c867 100644 --- a/src/pythae/models/msssim_vae/msssim_vae_model.py +++ b/src/pythae/models/msssim_vae/msssim_vae_model.py @@ -72,7 +72,7 @@ def forward(self, inputs: BaseDataset, **kwargs): loss, recon_loss, kld = self.loss_function(recon_x, x, mu, log_var, z) output = ModelOutput( - reconstruction_loss=recon_loss, + recon_loss=recon_loss, reg_loss=kld, loss=loss, recon_x=recon_x, diff --git a/src/pythae/models/piwae/piwae_model.py b/src/pythae/models/piwae/piwae_model.py index 167c80ab..01dcaa11 100644 --- a/src/pythae/models/piwae/piwae_model.py +++ b/src/pythae/models/piwae/piwae_model.py @@ -90,7 +90,7 @@ def forward(self, inputs: BaseDataset, **kwargs): loss = miwae_loss + iwae_loss output = ModelOutput( - reconstruction_loss=recon_loss, + recon_loss=recon_loss, reg_loss=kld, loss=loss, encoder_loss=miwae_loss, diff --git a/src/pythae/models/pvae/pvae_model.py b/src/pythae/models/pvae/pvae_model.py index 189f8677..d15462e6 100644 --- a/src/pythae/models/pvae/pvae_model.py +++ b/src/pythae/models/pvae/pvae_model.py @@ -115,7 +115,7 @@ def forward(self, inputs: BaseDataset, **kwargs): loss, recon_loss, kld = self.loss_function(recon_x, x, z, qz_x) output = ModelOutput( - reconstruction_loss=recon_loss, + recon_loss=recon_loss, reg_loss=kld, loss=loss, recon_x=recon_x, diff --git a/src/pythae/models/svae/svae_model.py b/src/pythae/models/svae/svae_model.py index 2f7d9f0a..d902aba1 100644 --- a/src/pythae/models/svae/svae_model.py +++ b/src/pythae/models/svae/svae_model.py @@ -88,7 +88,7 @@ def forward(self, inputs: BaseDataset, **kwargs): loss, recon_loss, kld = self.loss_function(recon_x, x, loc, concentration, z) output = ModelOutput( - reconstruction_loss=recon_loss, + recon_loss=recon_loss, reg_loss=kld, loss=loss, recon_x=recon_x, diff --git a/src/pythae/models/vae/vae_model.py b/src/pythae/models/vae/vae_model.py index 0f277605..9d5a90b2 100644 --- a/src/pythae/models/vae/vae_model.py +++ b/src/pythae/models/vae/vae_model.py @@ -88,7 +88,7 @@ def forward(self, inputs: BaseDataset, **kwargs): loss, recon_loss, kld = self.loss_function(recon_x, x, mu, log_var, z) output = ModelOutput( - reconstruction_loss=recon_loss, + recon_loss=recon_loss, reg_loss=kld, loss=loss, recon_x=recon_x, diff --git a/src/pythae/models/vae_iaf/vae_iaf_model.py b/src/pythae/models/vae_iaf/vae_iaf_model.py index 37c646c8..9f762794 100644 --- a/src/pythae/models/vae_iaf/vae_iaf_model.py +++ b/src/pythae/models/vae_iaf/vae_iaf_model.py @@ -93,7 +93,7 @@ def forward(self, inputs: BaseDataset, **kwargs): ) output = ModelOutput( - reconstruction_loss=recon_loss, + recon_loss=recon_loss, reg_loss=kld, loss=loss, recon_x=recon_x, diff --git a/src/pythae/models/vae_lin_nf/vae_lin_nf_model.py b/src/pythae/models/vae_lin_nf/vae_lin_nf_model.py index e4cd5164..e8bca0f4 100644 --- a/src/pythae/models/vae_lin_nf/vae_lin_nf_model.py +++ b/src/pythae/models/vae_lin_nf/vae_lin_nf_model.py @@ -106,7 +106,7 @@ def forward(self, inputs: BaseDataset, **kwargs): ) output = ModelOutput( - reconstruction_loss=recon_loss, + recon_loss=recon_loss, reg_loss=kld, loss=loss, recon_x=recon_x, diff --git a/src/pythae/models/vamp/vamp_model.py b/src/pythae/models/vamp/vamp_model.py index 8c582fcc..cfd9161a 100644 --- a/src/pythae/models/vamp/vamp_model.py +++ b/src/pythae/models/vamp/vamp_model.py @@ -91,7 +91,7 @@ def forward(self, inputs: BaseDataset, **kwargs): loss, recon_loss, kld = self.loss_function(recon_x, x, mu, log_var, z, epoch) output = ModelOutput( - reconstruction_loss=recon_loss, + recon_loss=recon_loss, reg_loss=kld, loss=loss, recon_x=recon_x, diff --git a/tests/test_AE.py b/tests/test_AE.py index 3c12b653..d17aad7b 100644 --- a/tests/test_AE.py +++ b/tests/test_AE.py @@ -355,6 +355,29 @@ def test_reconstruct(self, ae, demo_data): recon = ae.reconstruct(demo_data) assert tuple(recon.shape) == demo_data.shape +class Test_Model_predict: + @pytest.fixture( + params=[ + torch.randn(3, 2, 3, 1), + torch.randn(3, 2, 2), + torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[:][ + "data" + ], + ] + ) + def demo_data(self, request): + return request.param + + @pytest.fixture() + def ae(self, model_configs, demo_data): + model_configs.input_dim = tuple(demo_data[0].shape) + return AE(model_configs) + + def test_predict(self, ae, demo_data): + + model_output = ae.predict(demo_data) + assert tuple(model_output.recon_x.shape) == demo_data.shape + assert tuple(model_output.embedding.shape) == ae.model_config.latent_dim.shape @pytest.mark.slow class Test_AE_Training: diff --git a/tests/test_BetaTCVAE.py b/tests/test_BetaTCVAE.py index ea5f6829..396be3ee 100644 --- a/tests/test_BetaTCVAE.py +++ b/tests/test_BetaTCVAE.py @@ -298,7 +298,7 @@ def test_model_train_output(self, betavae, demo_data): assert isinstance(out, ModelOutput) - assert set(["reconstruction_loss", "reg_loss", "loss", "recon_x", "z"]) == set( + assert set(["recon_loss", "reg_loss", "loss", "recon_x", "z"]) == set( out.keys() ) diff --git a/tests/test_BetaVAE.py b/tests/test_BetaVAE.py index dd821559..051837cc 100644 --- a/tests/test_BetaVAE.py +++ b/tests/test_BetaVAE.py @@ -289,7 +289,7 @@ def test_model_train_output(self, betavae, demo_data): assert isinstance(out, ModelOutput) - assert set(["reconstruction_loss", "reg_loss", "loss", "recon_x", "z"]) == set( + assert set(["recon_loss", "reg_loss", "loss", "recon_x", "z"]) == set( out.keys() ) diff --git a/tests/test_CIWAE.py b/tests/test_CIWAE.py index 74df8809..d1b0a827 100644 --- a/tests/test_CIWAE.py +++ b/tests/test_CIWAE.py @@ -297,7 +297,7 @@ def test_model_train_output(self, CIWAE, demo_data): assert isinstance(out, ModelOutput) - assert set(["reconstruction_loss", "reg_loss", "loss", "recon_x", "z"]) == set( + assert set(["recon_loss", "reg_loss", "loss", "recon_x", "z"]) == set( out.keys() ) diff --git a/tests/test_DisentangledBetaVAE.py b/tests/test_DisentangledBetaVAE.py index f74618a9..8c567a18 100644 --- a/tests/test_DisentangledBetaVAE.py +++ b/tests/test_DisentangledBetaVAE.py @@ -308,7 +308,7 @@ def test_model_train_output(self, betavae, demo_data): assert isinstance(out, ModelOutput) - assert set(["reconstruction_loss", "reg_loss", "loss", "recon_x", "z"]) == set( + assert set(["recon_loss", "reg_loss", "loss", "recon_x", "z"]) == set( out.keys() ) diff --git a/tests/test_IWAE.py b/tests/test_IWAE.py index cfc0f716..57220f96 100644 --- a/tests/test_IWAE.py +++ b/tests/test_IWAE.py @@ -301,7 +301,7 @@ def test_model_train_output(self, iwae, demo_data): assert isinstance(out, ModelOutput) - assert set(["reconstruction_loss", "reg_loss", "loss", "recon_x", "z"]) == set( + assert set(["recon_loss", "reg_loss", "loss", "recon_x", "z"]) == set( out.keys() ) diff --git a/tests/test_MIWAE.py b/tests/test_MIWAE.py index 80424203..6030c029 100644 --- a/tests/test_MIWAE.py +++ b/tests/test_MIWAE.py @@ -296,7 +296,7 @@ def test_model_train_output(self, MIWAE, demo_data): assert isinstance(out, ModelOutput) - assert set(["reconstruction_loss", "reg_loss", "loss", "recon_x", "z"]) == set( + assert set(["recon_loss", "reg_loss", "loss", "recon_x", "z"]) == set( out.keys() ) diff --git a/tests/test_MSSSIMVAE.py b/tests/test_MSSSIMVAE.py index dda5d998..929370fa 100644 --- a/tests/test_MSSSIMVAE.py +++ b/tests/test_MSSSIMVAE.py @@ -299,7 +299,7 @@ def test_model_train_output(self, msssim_vae, demo_data): assert isinstance(out, ModelOutput) - assert set(["reconstruction_loss", "reg_loss", "loss", "recon_x", "z"]) == set( + assert set(["recon_loss", "reg_loss", "loss", "recon_x", "z"]) == set( out.keys() ) diff --git a/tests/test_PIWAE.py b/tests/test_PIWAE.py index 67346738..78c66928 100644 --- a/tests/test_PIWAE.py +++ b/tests/test_PIWAE.py @@ -296,7 +296,7 @@ def test_model_train_output(self, piwae, demo_data): set( [ "loss", - "reconstruction_loss", + "recon_loss", "encoder_loss", "decoder_loss", "update_encoder", diff --git a/tests/test_PoincareVAE.py b/tests/test_PoincareVAE.py index a292800d..e54d4694 100644 --- a/tests/test_PoincareVAE.py +++ b/tests/test_PoincareVAE.py @@ -334,7 +334,7 @@ def test_model_train_output(self, vae, demo_data): assert isinstance(out, ModelOutput) - assert set(["reconstruction_loss", "reg_loss", "loss", "recon_x", "z"]) == set( + assert set(["recon_loss", "reg_loss", "loss", "recon_x", "z"]) == set( out.keys() ) diff --git a/tests/test_SVAE.py b/tests/test_SVAE.py index e978b30a..5cd9f389 100644 --- a/tests/test_SVAE.py +++ b/tests/test_SVAE.py @@ -283,7 +283,7 @@ def test_model_train_output(self, svae, demo_data): assert isinstance(out, ModelOutput) - assert set(["reconstruction_loss", "reg_loss", "loss", "recon_x", "z"]) == set( + assert set(["recon_loss", "reg_loss", "loss", "recon_x", "z"]) == set( out.keys() ) diff --git a/tests/test_VAE.py b/tests/test_VAE.py index f914daf8..83cde490 100644 --- a/tests/test_VAE.py +++ b/tests/test_VAE.py @@ -289,7 +289,7 @@ def test_model_train_output(self, vae, demo_data): assert isinstance(out, ModelOutput) - assert set(["reconstruction_loss", "reg_loss", "loss", "recon_x", "z"]) == set( + assert set(["recon_loss", "reg_loss", "loss", "recon_x", "z"]) == set( out.keys() ) diff --git a/tests/test_VAE_IAF.py b/tests/test_VAE_IAF.py index 6e9b7274..60de1d97 100644 --- a/tests/test_VAE_IAF.py +++ b/tests/test_VAE_IAF.py @@ -295,7 +295,7 @@ def test_model_train_output(self, vae, demo_data): assert isinstance(out, ModelOutput) - assert set(["reconstruction_loss", "reg_loss", "loss", "recon_x", "z"]) == set( + assert set(["recon_loss", "reg_loss", "loss", "recon_x", "z"]) == set( out.keys() ) diff --git a/tests/test_VAE_LinFlow.py b/tests/test_VAE_LinFlow.py index 91d121b0..4dc3421b 100644 --- a/tests/test_VAE_LinFlow.py +++ b/tests/test_VAE_LinFlow.py @@ -309,7 +309,7 @@ def test_model_train_output(self, vae, demo_data): assert isinstance(out, ModelOutput) - assert set(["reconstruction_loss", "reg_loss", "loss", "recon_x", "z"]) == set( + assert set(["recon_loss", "reg_loss", "loss", "recon_x", "z"]) == set( out.keys() ) diff --git a/tests/test_VAMP.py b/tests/test_VAMP.py index 548f68aa..9408fd73 100644 --- a/tests/test_VAMP.py +++ b/tests/test_VAMP.py @@ -287,7 +287,7 @@ def test_model_train_output(self, vamp, demo_data): assert isinstance(out, ModelOutput) - assert set(["reconstruction_loss", "reg_loss", "loss", "recon_x", "z"]) == set( + assert set(["recon_loss", "reg_loss", "loss", "recon_x", "z"]) == set( out.keys() ) diff --git a/tests/test_info_vae_mmd.py b/tests/test_info_vae_mmd.py index 016d558e..1ab10266 100644 --- a/tests/test_info_vae_mmd.py +++ b/tests/test_info_vae_mmd.py @@ -301,7 +301,7 @@ def test_model_train_output(self, info_vae_mmd, demo_data): assert isinstance(out, ModelOutput) assert set( - ["reconstruction_loss", "reg_loss", "loss", "mmd_loss", "recon_x", "z"] + ["recon_loss", "reg_loss", "loss", "mmd_loss", "recon_x", "z"] ) == set(out.keys()) assert out.z.shape[0] == demo_data["data"].shape[0] From 1a8652cc3588b62e16a64d7da518494332664724 Mon Sep 17 00:00:00 2001 From: Soumick Chatterjee Date: Wed, 8 Feb 2023 00:19:13 +0100 Subject: [PATCH 02/15] minor modifications for PR 75 --- src/pythae/models/base/base_model.py | 11 +++++------ tests/test_AE.py | 3 +-- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/src/pythae/models/base/base_model.py b/src/pythae/models/base/base_model.py index 50b52293..aae43fa1 100644 --- a/src/pythae/models/base/base_model.py +++ b/src/pythae/models/base/base_model.py @@ -117,17 +117,16 @@ def reconstruct(self, inputs: torch.Tensor): """ return self(DatasetOutput(data=inputs)).recon_x - def predict(self, inputs: BaseDataset, **kwargs) -> ModelOutput: + def predict(self, inputs: torch.Tensor) -> ModelOutput: """The input data is encoded and decoded without computing loss + Args: - inputs (BaseDataset): An instance of pythae's datasets + inputs (torch.Tensor): The input data to be reconstructed, as well as to generate the embedding. + Returns: ModelOutput: An instance of ModelOutput containing reconstruction and embedding """ - - x = inputs["data"] - - z = self.encoder(x).embedding + z = self.encoder(inputs).embedding recon_x = self.decoder(z)["reconstruction"] output = ModelOutput( diff --git a/tests/test_AE.py b/tests/test_AE.py index d17aad7b..c9d7d4be 100644 --- a/tests/test_AE.py +++ b/tests/test_AE.py @@ -376,8 +376,7 @@ def ae(self, model_configs, demo_data): def test_predict(self, ae, demo_data): model_output = ae.predict(demo_data) - assert tuple(model_output.recon_x.shape) == demo_data.shape - assert tuple(model_output.embedding.shape) == ae.model_config.latent_dim.shape + assert tuple(model_output.embedding.shape) == (demo_data.shape[0], ae.model_config.latent_dim) @pytest.mark.slow class Test_AE_Training: From ac73bb36ca3996461ce3496e6b75679e411b40d8 Mon Sep 17 00:00:00 2001 From: Soumick Chatterjee Date: Wed, 1 Mar 2023 14:26:40 +0100 Subject: [PATCH 03/15] Update SVAE, making it similar to the original implementation the values are clamped before sqrt to avoid NaN during training, as done in the original implementation https://github.com/nicola-decao/s-vae-pytorch/blob/master/hyperspherical_vae/distributions/von_mises_fisher.py#L68 --- src/pythae/models/svae/svae_model.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/pythae/models/svae/svae_model.py b/src/pythae/models/svae/svae_model.py index d902aba1..90dc6d56 100644 --- a/src/pythae/models/svae/svae_model.py +++ b/src/pythae/models/svae/svae_model.py @@ -147,7 +147,8 @@ def _sample_von_mises(self, loc, concentration): w = self._acc_rej_steps(m=loc.shape[-1], k=concentration) - z = torch.cat((w, (1 - w ** 2).sqrt() * v), dim=-1) + w_ = torch.sqrt(torch.clamp(1 - (w ** 2), 1e-10)) + z = torch.cat((w, w_ * v), dim=-1) return self._householder_rotation(loc, z) From 472e3a39f241421d8b7ef25d97c49ebe9a864482 Mon Sep 17 00:00:00 2001 From: "Soumick Chatterjee, PhD" Date: Wed, 5 Apr 2023 10:30:11 +0000 Subject: [PATCH 04/15] L1 loss added to the models --- .../adversarial_ae/adversarial_ae_model.py | 8 ++++++ .../models/beta_tc_vae/beta_tc_vae_model.py | 8 ++++++ src/pythae/models/beta_vae/beta_vae_model.py | 8 ++++++ src/pythae/models/ciwae/ciwae_model.py | 10 +++++++ .../disentangled_beta_vae_model.py | 8 ++++++ .../models/factor_vae/factor_vae_model.py | 8 ++++++ src/pythae/models/info_vae/info_vae_model.py | 8 ++++++ src/pythae/models/iwae/iwae_model.py | 10 +++++++ src/pythae/models/miwae/miwae_model.py | 11 ++++++++ src/pythae/models/piwae/piwae_model.py | 11 ++++++++ src/pythae/models/pvae/pvae_model.py | 20 ++++++++++++++ src/pythae/models/rhvae/rhvae_model.py | 26 +++++++++++++++++++ src/pythae/models/svae/svae_model.py | 20 ++++++++++++++ src/pythae/models/vae/vae_model.py | 20 ++++++++++++++ src/pythae/models/vae_iaf/vae_iaf_model.py | 20 ++++++++++++++ .../models/vae_lin_nf/vae_lin_nf_model.py | 20 ++++++++++++++ src/pythae/models/vamp/vamp_model.py | 20 ++++++++++++++ tests/test_Adversarial_AE.py | 3 +++ tests/test_BetaTCVAE.py | 3 +++ tests/test_BetaVAE.py | 1 + tests/test_CIWAE.py | 5 ++++ tests/test_DisentangledBetaVAE.py | 3 +++ tests/test_FactorVAE.py | 3 +++ tests/test_IWAE.py | 5 ++++ tests/test_MIWAE.py | 5 ++++ tests/test_MSSSIMVAE.py | 3 +++ tests/test_PIWAE.py | 2 ++ tests/test_PoincareVAE.py | 8 ++++++ tests/test_RHVAE.py | 3 +++ tests/test_SVAE.py | 1 + tests/test_VAE.py | 1 + tests/test_VAEGAN.py | 1 + tests/test_VAE_IAF.py | 1 + tests/test_VAE_LinFlow.py | 6 +++++ tests/test_VAMP.py | 1 + tests/test_info_vae_mmd.py | 5 ++++ 36 files changed, 296 insertions(+) diff --git a/src/pythae/models/adversarial_ae/adversarial_ae_model.py b/src/pythae/models/adversarial_ae/adversarial_ae_model.py index bd9bfcbe..43da703e 100644 --- a/src/pythae/models/adversarial_ae/adversarial_ae_model.py +++ b/src/pythae/models/adversarial_ae/adversarial_ae_model.py @@ -168,6 +168,14 @@ def loss_function(self, recon_x, x, z, z_prior): reduction="none", ).sum(dim=-1) + elif self.model_config.reconstruction_loss == "l1": + + recon_loss = F.l1_loss( + recon_x.reshape(x.shape[0], -1), + x.reshape(x.shape[0], -1), + reduction="none", + ).sum(dim=-1) + gen_adversarial_score = self.discriminator(z).embedding.flatten() prior_adversarial_score = self.discriminator(z_prior).embedding.flatten() diff --git a/src/pythae/models/beta_tc_vae/beta_tc_vae_model.py b/src/pythae/models/beta_tc_vae/beta_tc_vae_model.py index cd8d76ac..dc8a56bc 100644 --- a/src/pythae/models/beta_tc_vae/beta_tc_vae_model.py +++ b/src/pythae/models/beta_tc_vae/beta_tc_vae_model.py @@ -106,6 +106,14 @@ def loss_function(self, recon_x, x, mu, log_var, z, dataset_size): reduction="none", ).sum(dim=-1) + elif self.model_config.reconstruction_loss == "l1": + + recon_loss = F.l1_loss( + recon_x.reshape(x.shape[0], -1), + x.reshape(x.shape[0], -1), + reduction="none", + ).sum(dim=-1) + log_q_z_given_x = self._compute_log_gauss_density(z, mu, log_var).sum( dim=-1 ) # [B] diff --git a/src/pythae/models/beta_vae/beta_vae_model.py b/src/pythae/models/beta_vae/beta_vae_model.py index 7e7ed480..b11a6521 100644 --- a/src/pythae/models/beta_vae/beta_vae_model.py +++ b/src/pythae/models/beta_vae/beta_vae_model.py @@ -98,6 +98,14 @@ def loss_function(self, recon_x, x, mu, log_var, z): reduction="none", ).sum(dim=-1) + elif self.model_config.reconstruction_loss == "l1": + + recon_loss = F.l1_loss( + recon_x.reshape(x.shape[0], -1), + x.reshape(x.shape[0], -1), + reduction="none", + ).sum(dim=-1) + KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp(), dim=-1) return ( diff --git a/src/pythae/models/ciwae/ciwae_model.py b/src/pythae/models/ciwae/ciwae_model.py index 6b8158a4..07891bf3 100644 --- a/src/pythae/models/ciwae/ciwae_model.py +++ b/src/pythae/models/ciwae/ciwae_model.py @@ -113,6 +113,16 @@ def loss_function(self, recon_x, x, mu, log_var, z): reduction="none", ).sum(dim=-1) + elif self.model_config.reconstruction_loss == "l1": + + recon_loss = F.l1_loss( + recon_x, + x.reshape(recon_x.shape[0], -1) + .unsqueeze(1) + .repeat(1, self.n_samples, 1), + reduction="none", + ).sum(dim=-1) + log_q_z = (-0.5 * (log_var + torch.pow(z - mu, 2) / log_var.exp())).sum(dim=-1) log_p_z = -0.5 * (z ** 2).sum(dim=-1) diff --git a/src/pythae/models/disentangled_beta_vae/disentangled_beta_vae_model.py b/src/pythae/models/disentangled_beta_vae/disentangled_beta_vae_model.py index 7d935290..dd6e26d9 100644 --- a/src/pythae/models/disentangled_beta_vae/disentangled_beta_vae_model.py +++ b/src/pythae/models/disentangled_beta_vae/disentangled_beta_vae_model.py @@ -106,6 +106,14 @@ def loss_function(self, recon_x, x, mu, log_var, z, epoch): reduction="none", ).sum(dim=-1) + elif self.model_config.reconstruction_loss == "l1": + + recon_loss = F.l1_loss( + recon_x.reshape(x.shape[0], -1), + x.reshape(x.shape[0], -1), + reduction="none", + ).sum(dim=-1) + KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp(), dim=-1) C_factor = min(epoch / (self.warmup_epoch + 1), 1) KLD_diff = torch.abs(KLD - self.C * C_factor) diff --git a/src/pythae/models/factor_vae/factor_vae_model.py b/src/pythae/models/factor_vae/factor_vae_model.py index e4b2b164..97c691eb 100644 --- a/src/pythae/models/factor_vae/factor_vae_model.py +++ b/src/pythae/models/factor_vae/factor_vae_model.py @@ -151,6 +151,14 @@ def loss_function(self, recon_x, x, mu, log_var, z, z_bis_permuted): reduction="none", ).sum(dim=-1) + elif self.model_config.reconstruction_loss == "l1": + + recon_loss = F.l1_loss( + recon_x.reshape(x.shape[0], -1), + x.reshape(x.shape[0], -1), + reduction="none", + ).sum(dim=-1) + KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp(), dim=-1) latent_adversarial_score = self.discriminator(z) diff --git a/src/pythae/models/info_vae/info_vae_model.py b/src/pythae/models/info_vae/info_vae_model.py index 915b5c7a..1e96191d 100644 --- a/src/pythae/models/info_vae/info_vae_model.py +++ b/src/pythae/models/info_vae/info_vae_model.py @@ -106,6 +106,14 @@ def loss_function(self, recon_x, x, z, z_prior, mu, log_var): reduction="none", ).sum(dim=-1) + elif self.model_config.reconstruction_loss == "l1": + + recon_loss = F.l1_loss( + recon_x.reshape(x.shape[0], -1), + x.reshape(x.shape[0], -1), + reduction="none", + ).sum(dim=-1) + KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp(), dim=-1) if self.kernel_choice == "rbf": diff --git a/src/pythae/models/iwae/iwae_model.py b/src/pythae/models/iwae/iwae_model.py index 7b03b981..6902c28f 100644 --- a/src/pythae/models/iwae/iwae_model.py +++ b/src/pythae/models/iwae/iwae_model.py @@ -112,6 +112,16 @@ def loss_function(self, recon_x, x, mu, log_var, z): reduction="none", ).sum(dim=-1) + elif self.model_config.reconstruction_loss == "l1": + + recon_loss = F.l1_loss( + recon_x, + x.reshape(recon_x.shape[0], -1) + .unsqueeze(1) + .repeat(1, self.n_samples, 1), + reduction="none", + ).sum(dim=-1) + log_q_z = (-0.5 * (log_var + torch.pow(z - mu, 2) / log_var.exp())).sum(dim=-1) log_p_z = -0.5 * (z ** 2).sum(dim=-1) diff --git a/src/pythae/models/miwae/miwae_model.py b/src/pythae/models/miwae/miwae_model.py index e9549022..c5ec1fc7 100644 --- a/src/pythae/models/miwae/miwae_model.py +++ b/src/pythae/models/miwae/miwae_model.py @@ -123,6 +123,17 @@ def loss_function(self, recon_x, x, mu, log_var, z): reduction="none", ).sum(dim=-1) + elif self.model_config.reconstruction_loss == "l1": + + recon_loss = F.l1_loss( + recon_x, + x.reshape(recon_x.shape[0], -1) + .unsqueeze(1) + .unsqueeze(1) + .repeat(1, self.gradient_n_estimates, self.n_samples, 1), + reduction="none", + ).sum(dim=-1) + log_q_z = (-0.5 * (log_var + torch.pow(z - mu, 2) / log_var.exp())).sum(dim=-1) log_p_z = -0.5 * (z ** 2).sum(dim=-1) diff --git a/src/pythae/models/piwae/piwae_model.py b/src/pythae/models/piwae/piwae_model.py index 01dcaa11..2d0d454b 100644 --- a/src/pythae/models/piwae/piwae_model.py +++ b/src/pythae/models/piwae/piwae_model.py @@ -131,6 +131,17 @@ def loss_function(self, recon_x, x, mu, log_var, z): reduction="none", ).sum(dim=-1) + elif self.model_config.reconstruction_loss == "l1": + + recon_loss = F.l1_loss( + recon_x, + x.reshape(recon_x.shape[0], -1) + .unsqueeze(1) + .unsqueeze(1) + .repeat(1, self.gradient_n_estimates, self.n_samples, 1), + reduction="none", + ).sum(dim=-1) + log_q_z = (-0.5 * (log_var + torch.pow(z - mu, 2) / log_var.exp())).sum(dim=-1) log_p_z = -0.5 * (z ** 2).sum(dim=-1) diff --git a/src/pythae/models/pvae/pvae_model.py b/src/pythae/models/pvae/pvae_model.py index d15462e6..03a215a6 100644 --- a/src/pythae/models/pvae/pvae_model.py +++ b/src/pythae/models/pvae/pvae_model.py @@ -142,6 +142,14 @@ def loss_function(self, recon_x, x, z, qz_x): reduction="none", ).sum(dim=-1) + elif self.model_config.reconstruction_loss == "l1": + + recon_loss = F.l1_loss( + recon_x.reshape(x.shape[0], -1), + x.reshape(x.shape[0], -1), + reduction="none", + ).sum(dim=-1) + pz = self.prior( loc=self._pz_mu, scale=self._pz_logvar.exp(), manifold=self.latent_manifold ) @@ -278,6 +286,18 @@ def get_nll(self, data, n_samples=1, batch_size=100): reduction="none", ).sum(dim=-1) + elif self.model_config.reconstruction_loss == "l1": + + log_p_x_given_z = -0.5 * F.l1_loss( + recon_x.reshape(x_rep.shape[0], -1), + x_rep.reshape(x_rep.shape[0], -1), + reduction="none", + ).sum(dim=-1) - torch.tensor( + [np.prod(self.input_dim) / 2 * np.log(np.pi * 2)] + ).to( + data.device + ) # decoding distribution is assumed unit variance N(mu, I) + log_p_x.append(log_p_x_given_z + log_p_z - log_q_z_given_x) log_p_x = torch.cat(log_p_x) diff --git a/src/pythae/models/rhvae/rhvae_model.py b/src/pythae/models/rhvae/rhvae_model.py index 9b3cbfad..8ce22f7a 100644 --- a/src/pythae/models/rhvae/rhvae_model.py +++ b/src/pythae/models/rhvae/rhvae_model.py @@ -424,6 +424,20 @@ def _log_p_x_given_z(self, recon_x, x): reduction="none", ).sum(dim=-1) + elif self.model_config.reconstruction_loss == "l1": + # sigma is taken as I_D + recon_loss = ( + -0.5 + * F.l1_loss( + recon_x.reshape(x.shape[0], -1), + x.reshape(x.shape[0], -1), + reduction="none", + ).sum(dim=-1) + ) + -torch.log(torch.tensor([2 * np.pi]).to(x.device)) * np.prod( + self.input_dim + ) / 2 + return recon_loss def _log_z(self, z): @@ -572,6 +586,18 @@ def get_nll(self, data, n_samples=1, batch_size=100): reduction="none", ).sum(dim=-1) + elif self.model_config.reconstruction_loss == "l1": + + log_p_x_given_z = -0.5 * F.l1_loss( + recon_x.reshape(x_rep.shape[0], -1), + x_rep.reshape(x_rep.shape[0], -1), + reduction="none", + ).sum(dim=-1) - torch.tensor( + [np.prod(self.input_dim) / 2 * np.log(np.pi * 2)] + ).to( + data.device + ) # decoding distribution is assumed unit variance N(mu, I) + log_p_x.append( log_p_x_given_z + log_p_z diff --git a/src/pythae/models/svae/svae_model.py b/src/pythae/models/svae/svae_model.py index 90dc6d56..d6247bc4 100644 --- a/src/pythae/models/svae/svae_model.py +++ b/src/pythae/models/svae/svae_model.py @@ -115,6 +115,14 @@ def loss_function(self, recon_x, x, loc, concentration, z): reduction="none", ).sum(dim=-1) + elif self.model_config.reconstruction_loss == "l1": + + recon_loss = F.l1_loss( + recon_x.reshape(x.shape[0], -1), + x.reshape(x.shape[0], -1), + reduction="none", + ).sum(dim=-1) + KLD = self._compute_kl(m=loc.shape[-1], concentration=concentration) return (recon_loss + KLD).mean(dim=0), recon_loss.mean(dim=0), KLD.mean(dim=0) @@ -286,6 +294,18 @@ def get_nll(self, data, n_samples=1, batch_size=100): reduction="none", ).sum(dim=-1) + elif self.model_config.reconstruction_loss == "l1": + + log_p_x_given_z = -0.5 * F.l1_loss( + recon_x.reshape(x_rep.shape[0], -1), + x_rep.reshape(x_rep.shape[0], -1), + reduction="none", + ).sum(dim=-1) - torch.tensor( + [np.prod(self.input_dim) / 2 * np.log(np.pi * 2)] + ).to( + data.device + ) # decoding distribution is assumed unit variance N(mu, I) + log_p_x.append( log_p_x_given_z + log_p_z - log_q_z_given_x ) # log(2*pi) simplifies diff --git a/src/pythae/models/vae/vae_model.py b/src/pythae/models/vae/vae_model.py index 0f5df29b..72a91870 100644 --- a/src/pythae/models/vae/vae_model.py +++ b/src/pythae/models/vae/vae_model.py @@ -115,6 +115,14 @@ def loss_function(self, recon_x, x, mu, log_var, z): reduction="none", ).sum(dim=-1) + elif self.model_config.reconstruction_loss == "l1": + + recon_loss = F.l1_loss( + recon_x.reshape(x.shape[0], -1), + x.reshape(x.shape[0], -1), + reduction="none", + ).sum(dim=-1) + KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp(), dim=-1) return (recon_loss + KLD).mean(dim=0), recon_loss.mean(dim=0), KLD.mean(dim=0) @@ -186,6 +194,18 @@ def get_nll(self, data, n_samples=1, batch_size=100): reduction="none", ).sum(dim=-1) + elif self.model_config.reconstruction_loss == "l1": + + log_p_x_given_z = -0.5 * F.l1_loss( + recon_x.reshape(x_rep.shape[0], -1), + x_rep.reshape(x_rep.shape[0], -1), + reduction="none", + ).sum(dim=-1) - torch.tensor( + [np.prod(self.input_dim) / 2 * np.log(np.pi * 2)] + ).to( + data.device + ) # decoding distribution is assumed unit variance N(mu, I) + log_p_x.append( log_p_x_given_z + log_p_z - log_q_z_given_x ) # log(2*pi) simplifies diff --git a/src/pythae/models/vae_iaf/vae_iaf_model.py b/src/pythae/models/vae_iaf/vae_iaf_model.py index 9f762794..0f393b4a 100644 --- a/src/pythae/models/vae_iaf/vae_iaf_model.py +++ b/src/pythae/models/vae_iaf/vae_iaf_model.py @@ -120,6 +120,14 @@ def loss_function(self, recon_x, x, mu, log_var, z0, zk, log_abs_det_jac): reduction="none", ).sum(dim=-1) + elif self.model_config.reconstruction_loss == "l1": + + recon_loss = F.l1_loss( + recon_x.reshape(x.shape[0], -1), + x.reshape(x.shape[0], -1), + reduction="none", + ).sum(dim=-1) + # starting gaussian log-density log_prob_z0 = ( -0.5 * (log_var + torch.pow(z0 - mu, 2) / torch.exp(log_var)) @@ -212,6 +220,18 @@ def get_nll(self, data, n_samples=1, batch_size=100): reduction="none", ).sum(dim=-1) + elif self.model_config.reconstruction_loss == "l1": + + log_p_x_given_z = -0.5 * F.l1_loss( + recon_x.reshape(x_rep.shape[0], -1), + x_rep.reshape(x_rep.shape[0], -1), + reduction="none", + ).sum(dim=-1) - torch.tensor( + [np.prod(self.input_dim) / 2 * np.log(np.pi * 2)] + ).to( + data.device + ) # decoding distribution is assumed unit variance N(mu, I) + log_p_x.append( log_p_x_given_z + log_p_z - log_q_z_given_x ) # log(2*pi) simplifies diff --git a/src/pythae/models/vae_lin_nf/vae_lin_nf_model.py b/src/pythae/models/vae_lin_nf/vae_lin_nf_model.py index e8bca0f4..5667697a 100644 --- a/src/pythae/models/vae_lin_nf/vae_lin_nf_model.py +++ b/src/pythae/models/vae_lin_nf/vae_lin_nf_model.py @@ -133,6 +133,14 @@ def loss_function(self, recon_x, x, mu, log_var, z0, zk, log_abs_det_jac): reduction="none", ).sum(dim=-1) + elif self.model_config.reconstruction_loss == "l1": + + recon_loss = F.l1_loss( + recon_x.reshape(x.shape[0], -1), + x.reshape(x.shape[0], -1), + reduction="none", + ).sum(dim=-1) + # starting gaussian log-density log_prob_z0 = ( -0.5 * (log_var + torch.pow(z0 - mu, 2) / torch.exp(log_var)) @@ -226,6 +234,18 @@ def get_nll(self, data, n_samples=1, batch_size=100): reduction="none", ).sum(dim=-1) + elif self.model_config.reconstruction_loss == "l1": + + log_p_x_given_z = -0.5 * F.l1_loss( + recon_x.reshape(x_rep.shape[0], -1), + x_rep.reshape(x_rep.shape[0], -1), + reduction="none", + ).sum(dim=-1) - torch.tensor( + [np.prod(self.input_dim) / 2 * np.log(np.pi * 2)] + ).to( + data.device + ) # decoding distribution is assumed unit variance N(mu, I) + log_p_x.append( log_p_x_given_z + log_p_z - log_q_z_given_x ) # log(2*pi) simplifies diff --git a/src/pythae/models/vamp/vamp_model.py b/src/pythae/models/vamp/vamp_model.py index cfd9161a..571a610a 100644 --- a/src/pythae/models/vamp/vamp_model.py +++ b/src/pythae/models/vamp/vamp_model.py @@ -118,6 +118,14 @@ def loss_function(self, recon_x, x, mu, log_var, z, epoch): reduction="none", ).sum(dim=-1) + elif self.model_config.reconstruction_loss == "l1": + + recon_loss = F.l1_loss( + recon_x.reshape(x.shape[0], -1), + x.reshape(x.shape[0], -1), + reduction="none", + ).sum(dim=-1) + log_p_z = self._log_p_z(z) log_q_z = (-0.5 * (log_var + torch.pow(z - mu, 2) / log_var.exp())).sum(dim=1) @@ -243,6 +251,18 @@ def get_nll(self, data, n_samples=1, batch_size=100): reduction="none", ).sum(dim=-1) + elif self.model_config.reconstruction_loss == "l1": + + log_p_x_given_z = -0.5 * F.l1_loss( + recon_x.reshape(x_rep.shape[0], -1), + x_rep.reshape(x_rep.shape[0], -1), + reduction="none", + ).sum(dim=-1) - torch.tensor( + [np.prod(self.input_dim) / 2 * np.log(np.pi * 2)] + ).to( + data.device + ) # decoding distribution is assumed unit variance N(mu, I) + log_p_x.append( log_p_x_given_z + log_p_z - log_q_z_given_x ) # log(2*pi) simplifies diff --git a/tests/test_Adversarial_AE.py b/tests/test_Adversarial_AE.py index 0265fa8e..44ab33f1 100644 --- a/tests/test_Adversarial_AE.py +++ b/tests/test_Adversarial_AE.py @@ -43,6 +43,9 @@ def model_configs_no_input_dim(request): Adversarial_AE_Config( input_dim=(1, 28, 28), latent_dim=10, reconstruction_loss="bce" ), + Adversarial_AE_Config( + input_dim=(1, 28, 28), latent_dim=10, reconstruction_loss="l1" + ), Adversarial_AE_Config(input_dim=(1, 2, 18), latent_dim=5), ] ) diff --git a/tests/test_BetaTCVAE.py b/tests/test_BetaTCVAE.py index 396be3ee..9725f992 100644 --- a/tests/test_BetaTCVAE.py +++ b/tests/test_BetaTCVAE.py @@ -40,6 +40,9 @@ def model_configs_no_input_dim(request): BetaTCVAEConfig( input_dim=(1, 28, 28), latent_dim=10, reconstruction_loss="bce" ), + BetaTCVAEConfig( + input_dim=(1, 28, 28), latent_dim=10, reconstruction_loss="l1" + ), BetaTCVAEConfig( input_dim=(1, 28), latent_dim=5, beta=5.2, alpha=10, gamma=2, use_mss=False ), diff --git a/tests/test_BetaVAE.py b/tests/test_BetaVAE.py index 051837cc..0b8c928c 100644 --- a/tests/test_BetaVAE.py +++ b/tests/test_BetaVAE.py @@ -33,6 +33,7 @@ def model_configs_no_input_dim(request): @pytest.fixture( params=[ BetaVAEConfig(input_dim=(1, 28, 28), latent_dim=10, reconstruction_loss="bce"), + BetaVAEConfig(input_dim=(1, 28, 28), latent_dim=10, reconstruction_loss="l1"), BetaVAEConfig(input_dim=(1, 28), latent_dim=5, beta=5.2), ] ) diff --git a/tests/test_CIWAE.py b/tests/test_CIWAE.py index d1b0a827..38e4b769 100644 --- a/tests/test_CIWAE.py +++ b/tests/test_CIWAE.py @@ -38,6 +38,11 @@ def model_configs_no_input_dim(request): reconstruction_loss="bce", number_samples=2, ), + CIWAEConfig( + input_dim=(1, 28, 28), + latent_dim=10, + reconstruction_loss="l1" + ), CIWAEConfig(input_dim=(1, 28), latent_dim=5), ] ) diff --git a/tests/test_DisentangledBetaVAE.py b/tests/test_DisentangledBetaVAE.py index 8c567a18..6cda8ad7 100644 --- a/tests/test_DisentangledBetaVAE.py +++ b/tests/test_DisentangledBetaVAE.py @@ -40,6 +40,9 @@ def model_configs_no_input_dim(request): DisentangledBetaVAEConfig( input_dim=(1, 28, 28), latent_dim=10, reconstruction_loss="bce" ), + DisentangledBetaVAEConfig( + input_dim=(1, 28, 28), latent_dim=10, reconstruction_loss="l1" + ), DisentangledBetaVAEConfig( input_dim=(1, 28), latent_dim=5, beta=5.2, warmup_epoch=0 ), diff --git a/tests/test_FactorVAE.py b/tests/test_FactorVAE.py index e8a99765..87abf104 100644 --- a/tests/test_FactorVAE.py +++ b/tests/test_FactorVAE.py @@ -42,6 +42,9 @@ def model_configs_no_input_dim(request): FactorVAEConfig( input_dim=(1, 28, 28), latent_dim=10, reconstruction_loss="bce" ), + FactorVAEConfig( + input_dim=(1, 28, 28), latent_dim=10, reconstruction_loss="l1" + ), FactorVAEConfig(input_dim=(1, 2, 18), latent_dim=5, gamma=10), ] ) diff --git a/tests/test_IWAE.py b/tests/test_IWAE.py index 57220f96..63e9d64d 100644 --- a/tests/test_IWAE.py +++ b/tests/test_IWAE.py @@ -45,6 +45,11 @@ def model_configs_no_input_dim(request): reconstruction_loss="bce", number_samples=2, ), + IWAEConfig( + input_dim=(1, 28, 28), + latent_dim=10, + reconstruction_loss="l1" + ), IWAEConfig(input_dim=(1, 28), latent_dim=5), ] ) diff --git a/tests/test_MIWAE.py b/tests/test_MIWAE.py index 6030c029..eb1776dd 100644 --- a/tests/test_MIWAE.py +++ b/tests/test_MIWAE.py @@ -40,6 +40,11 @@ def model_configs_no_input_dim(request): reconstruction_loss="bce", number_samples=2, ), + MIWAEConfig( + input_dim=(1, 28, 28), + latent_dim=10, + reconstruction_loss="l1" + ), MIWAEConfig(input_dim=(1, 28), latent_dim=5), ] ) diff --git a/tests/test_MSSSIMVAE.py b/tests/test_MSSSIMVAE.py index 929370fa..6226760c 100644 --- a/tests/test_MSSSIMVAE.py +++ b/tests/test_MSSSIMVAE.py @@ -37,6 +37,9 @@ def model_configs_no_input_dim(request): MSSSIM_VAEConfig( input_dim=(1, 28, 28), latent_dim=10, reconstruction_loss="bce" ), + MSSSIM_VAEConfig( + input_dim=(1, 28, 28), latent_dim=10, reconstruction_loss="l1" + ), MSSSIM_VAEConfig(input_dim=(1, 28), latent_dim=5, beta=5.2), ] ) diff --git a/tests/test_PIWAE.py b/tests/test_PIWAE.py index 78c66928..60fe7554 100644 --- a/tests/test_PIWAE.py +++ b/tests/test_PIWAE.py @@ -36,6 +36,8 @@ def model_configs_no_input_dim(request): @pytest.fixture( params=[ PIWAEConfig(input_dim=(1, 28, 28), latent_dim=10, number_gradient_estimates=3), + PIWAEConfig(input_dim=(1, 28, 28), latent_dim=10, number_gradient_estimates=3, reconstruction_loss="bce"), + PIWAEConfig(input_dim=(1, 28, 28), latent_dim=10, number_gradient_estimates=3, reconstruction_loss="l1"), PIWAEConfig(input_dim=(1, 2, 18), latent_dim=5), ] ) diff --git a/tests/test_PoincareVAE.py b/tests/test_PoincareVAE.py index e54d4694..71548472 100644 --- a/tests/test_PoincareVAE.py +++ b/tests/test_PoincareVAE.py @@ -53,6 +53,14 @@ def model_configs_no_input_dim(request): posterior_distribution="wrapped_normal", curvature=0.7, ), + PoincareVAEConfig( + input_dim=(1, 28, 28), + latent_dim=2, + reconstruction_loss="l1", + prior_distribution="wrapped_normal", + posterior_distribution="wrapped_normal", + curvature=0.7, + ), PoincareVAEConfig( input_dim=(1, 28), latent_dim=5, diff --git a/tests/test_RHVAE.py b/tests/test_RHVAE.py index 35809316..7463b78e 100644 --- a/tests/test_RHVAE.py +++ b/tests/test_RHVAE.py @@ -38,6 +38,9 @@ def model_configs_no_input_dim(request): RHVAEConfig( input_dim=(1, 28, 28), latent_dim=1, n_lf=1, reconstruction_loss="bce" ), + RHVAEConfig( + input_dim=(1, 28, 28), latent_dim=1, n_lf=1, reconstruction_loss="l1" + ), RHVAEConfig(input_dim=(1, 2, 18), latent_dim=2, n_lf=1), ] ) diff --git a/tests/test_SVAE.py b/tests/test_SVAE.py index 5cd9f389..e780eae4 100644 --- a/tests/test_SVAE.py +++ b/tests/test_SVAE.py @@ -27,6 +27,7 @@ def model_configs_no_input_dim(request): @pytest.fixture( params=[ SVAEConfig(input_dim=(1, 28, 28), latent_dim=10, reconstruction_loss="bce"), + SVAEConfig(input_dim=(1, 28, 28), latent_dim=10, reconstruction_loss="l1"), SVAEConfig(input_dim=(1, 28), latent_dim=5), ] ) diff --git a/tests/test_VAE.py b/tests/test_VAE.py index 83cde490..6b07e29b 100644 --- a/tests/test_VAE.py +++ b/tests/test_VAE.py @@ -33,6 +33,7 @@ def model_configs_no_input_dim(request): @pytest.fixture( params=[ VAEConfig(input_dim=(1, 28, 28), latent_dim=10, reconstruction_loss="bce"), + VAEConfig(input_dim=(1, 28, 28), latent_dim=10, reconstruction_loss="l1"), VAEConfig(input_dim=(1, 28), latent_dim=5), ] ) diff --git a/tests/test_VAEGAN.py b/tests/test_VAEGAN.py index 8cd2b2b4..7fccab36 100644 --- a/tests/test_VAEGAN.py +++ b/tests/test_VAEGAN.py @@ -42,6 +42,7 @@ def model_configs_no_input_dim(request): @pytest.fixture( params=[ VAEGANConfig(input_dim=(1, 28, 28), latent_dim=10, reconstruction_loss="bce"), + VAEGANConfig(input_dim=(1, 28, 28), latent_dim=10, reconstruction_loss="l1"), VAEGANConfig(input_dim=(1, 2, 18), latent_dim=5, reconstruction_layer=1), ] ) diff --git a/tests/test_VAE_IAF.py b/tests/test_VAE_IAF.py index 60de1d97..3042f072 100644 --- a/tests/test_VAE_IAF.py +++ b/tests/test_VAE_IAF.py @@ -33,6 +33,7 @@ def model_configs_no_input_dim(request): @pytest.fixture( params=[ VAE_IAF_Config(input_dim=(1, 28, 28), latent_dim=10, reconstruction_loss="bce"), + VAE_IAF_Config(input_dim=(1, 28, 28), latent_dim=10, reconstruction_loss="l1"), VAE_IAF_Config( input_dim=(1, 28), latent_dim=5, diff --git a/tests/test_VAE_LinFlow.py b/tests/test_VAE_LinFlow.py index 4dc3421b..1ff3828b 100644 --- a/tests/test_VAE_LinFlow.py +++ b/tests/test_VAE_LinFlow.py @@ -38,6 +38,12 @@ def model_configs_no_input_dim(request): reconstruction_loss="bce", flows=["Planar", "Radial", "Planar"], ), + VAE_LinNF_Config( + input_dim=(1, 28, 28), + latent_dim=10, + reconstruction_loss="l1", + flows=["Planar", "Radial", "Planar"], + ), VAE_LinNF_Config( input_dim=(1, 28), latent_dim=5, flows=["Radial", "Radial", "Radial"] ), diff --git a/tests/test_VAMP.py b/tests/test_VAMP.py index 9408fd73..aca89572 100644 --- a/tests/test_VAMP.py +++ b/tests/test_VAMP.py @@ -27,6 +27,7 @@ def model_configs_no_input_dim(request): @pytest.fixture( params=[ VAMPConfig(input_dim=(1, 28, 28), latent_dim=10, reconstruction_loss="bce"), + VAMPConfig(input_dim=(1, 28, 28), latent_dim=10, reconstruction_loss="l1"), VAMPConfig(input_dim=(1, 2, 18), number_components=10, latent_dim=5), ] ) diff --git a/tests/test_info_vae_mmd.py b/tests/test_info_vae_mmd.py index 1ab10266..c67f31a3 100644 --- a/tests/test_info_vae_mmd.py +++ b/tests/test_info_vae_mmd.py @@ -38,6 +38,11 @@ def model_configs_no_input_dim(request): reconstruction_loss="bce", kernel_choice="rbf", ), + INFOVAE_MMD_Config( + input_dim=(1, 28, 28), + latent_dim=10, + reconstruction_loss="l1" + ), INFOVAE_MMD_Config(input_dim=(1, 28), latent_dim=5, lbd=5.2, alpha=0.2), ] ) From 44bbd5ce5547eaeb5f0b5bd1431129406c8acc41 Mon Sep 17 00:00:00 2001 From: "Soumick Chatterjee, PhD" Date: Wed, 5 Apr 2023 10:54:26 +0000 Subject: [PATCH 05/15] comments for configs updated --- src/pythae/models/adversarial_ae/adversarial_ae_config.py | 2 +- src/pythae/models/beta_tc_vae/beta_tc_vae_config.py | 2 +- src/pythae/models/beta_vae/beta_vae_config.py | 2 +- src/pythae/models/ciwae/ciwae_config.py | 2 +- .../disentangled_beta_vae/disentangled_beta_vae_config.py | 2 +- src/pythae/models/factor_vae/factor_vae_config.py | 2 +- src/pythae/models/info_vae/info_vae_config.py | 2 +- src/pythae/models/iwae/iwae_config.py | 2 +- src/pythae/models/miwae/miwae_config.py | 2 +- src/pythae/models/piwae/piwae_config.py | 2 +- src/pythae/models/pvae/pvae_config.py | 2 +- src/pythae/models/rhvae/rhvae_config.py | 1 + src/pythae/models/svae/svae_config.py | 2 +- src/pythae/models/vae/vae_config.py | 2 +- src/pythae/models/vae_iaf/vae_iaf_config.py | 2 +- src/pythae/models/vae_lin_nf/vae_lin_nf_config.py | 2 +- src/pythae/models/vamp/vamp_config.py | 2 +- 17 files changed, 17 insertions(+), 16 deletions(-) diff --git a/src/pythae/models/adversarial_ae/adversarial_ae_config.py b/src/pythae/models/adversarial_ae/adversarial_ae_config.py index 88c704e5..fc336e5b 100644 --- a/src/pythae/models/adversarial_ae/adversarial_ae_config.py +++ b/src/pythae/models/adversarial_ae/adversarial_ae_config.py @@ -10,7 +10,7 @@ class Adversarial_AE_Config(VAEConfig): Parameters: input_dim (tuple): The input_data dimension. latent_dim (int): The latent space dimension. Default: None. - reconstruction_loss (str): The reconstruction loss to use ['bce', 'mse']. Default: 'mse' + reconstruction_loss (str): The reconstruction loss to use ['bce', 'l1', 'mse']. Default: 'mse' adversarial_loss_scale (float): Parameter scaling the adversarial loss. Default: 0.5 reconstruction_loss_scale (float): Parameter scaling the reconstruction loss. Default: 1 deterministic_posterior (bool): Whether to use a deterministic posterior (Dirac). Default: diff --git a/src/pythae/models/beta_tc_vae/beta_tc_vae_config.py b/src/pythae/models/beta_tc_vae/beta_tc_vae_config.py index aebdddbd..9a50e1a5 100644 --- a/src/pythae/models/beta_tc_vae/beta_tc_vae_config.py +++ b/src/pythae/models/beta_tc_vae/beta_tc_vae_config.py @@ -11,7 +11,7 @@ class BetaTCVAEConfig(VAEConfig): Parameters: input_dim (tuple): The input_data dimension. latent_dim (int): The latent space dimension. Default: None. - reconstruction_loss (str): The reconstruction loss to use ['bce', 'mse']. Default: 'mse' + reconstruction_loss (str): The reconstruction loss to use ['bce', 'l1', 'mse']. Default: 'mse' alpha (float): The balancing factor before the Index code Mutual Info. Default: 1 beta (float): The balancing factor before the Total Correlation. Default: 1 gamma (float): The balancing factor before the dimension-wise KL. Default: 1 diff --git a/src/pythae/models/beta_vae/beta_vae_config.py b/src/pythae/models/beta_vae/beta_vae_config.py index 5d2c8aed..a8c04983 100644 --- a/src/pythae/models/beta_vae/beta_vae_config.py +++ b/src/pythae/models/beta_vae/beta_vae_config.py @@ -11,7 +11,7 @@ class BetaVAEConfig(VAEConfig): Parameters: input_dim (tuple): The input_data dimension. latent_dim (int): The latent space dimension. Default: None. - reconstruction_loss (str): The reconstruction loss to use ['bce', 'mse']. Default: 'mse' + reconstruction_loss (str): The reconstruction loss to use ['bce', 'l1', 'mse']. Default: 'mse' beta (float): The balancing factor. Default: 1 """ diff --git a/src/pythae/models/ciwae/ciwae_config.py b/src/pythae/models/ciwae/ciwae_config.py index 99f9eb9d..337c4048 100644 --- a/src/pythae/models/ciwae/ciwae_config.py +++ b/src/pythae/models/ciwae/ciwae_config.py @@ -10,7 +10,7 @@ class CIWAEConfig(VAEConfig): Parameters: input_dim (tuple): The input_data dimension. latent_dim (int): The latent space dimension. Default: None. - reconstruction_loss (str): The reconstruction loss to use ['bce', 'mse']. Default: 'mse' + reconstruction_loss (str): The reconstruction loss to use ['bce', 'l1', 'mse']. Default: 'mse' number_samples (int): Number of samples to use on the Monte-Carlo estimation beta (float): The value of the factor in the convex combination of the VAE and IWAE ELBO. Default: 0.5. diff --git a/src/pythae/models/disentangled_beta_vae/disentangled_beta_vae_config.py b/src/pythae/models/disentangled_beta_vae/disentangled_beta_vae_config.py index 51f10d65..9c69ef83 100644 --- a/src/pythae/models/disentangled_beta_vae/disentangled_beta_vae_config.py +++ b/src/pythae/models/disentangled_beta_vae/disentangled_beta_vae_config.py @@ -11,7 +11,7 @@ class DisentangledBetaVAEConfig(VAEConfig): Parameters: input_dim (tuple): The input_data dimension. latent_dim (int): The latent space dimension. Default: None. - reconstruction_loss (str): The reconstruction loss to use ['bce', 'mse']. Default: 'mse' + reconstruction_loss (str): The reconstruction loss to use ['bce', 'l1', 'mse']. Default: 'mse' beta (float): The balancing factor. Default: 10. C (float): The value of the KL divergence term of the ELBO we wish to approach, measured in nats. Default: 50. diff --git a/src/pythae/models/factor_vae/factor_vae_config.py b/src/pythae/models/factor_vae/factor_vae_config.py index df445a70..5e09351a 100644 --- a/src/pythae/models/factor_vae/factor_vae_config.py +++ b/src/pythae/models/factor_vae/factor_vae_config.py @@ -11,7 +11,7 @@ class FactorVAEConfig(VAEConfig): Parameters: input_dim (tuple): The input_data dimension. latent_dim (int): The latent space dimension. Default: None. - reconstruction_loss (str): The reconstruction loss to use ['bce', 'mse']. Default: 'mse' + reconstruction_loss (str): The reconstruction loss to use ['bce', 'l1', 'mse']. Default: 'mse' gamma (float): The balancing factor before the Total Correlation. Default: 0.5 """ gamma: float = 2.0 diff --git a/src/pythae/models/info_vae/info_vae_config.py b/src/pythae/models/info_vae/info_vae_config.py index 0a00d135..7a470826 100644 --- a/src/pythae/models/info_vae/info_vae_config.py +++ b/src/pythae/models/info_vae/info_vae_config.py @@ -14,7 +14,7 @@ class INFOVAE_MMD_Config(VAEConfig): Parameters: input_dim (tuple): The input_data dimension. latent_dim (int): The latent space dimension. Default: None. - reconstruction_loss (str): The reconstruction loss to use ['bce', 'mse']. Default: 'mse' + reconstruction_loss (str): The reconstruction loss to use ['bce', 'l1', 'mse']. Default: 'mse' kernel_choice (str): The kernel to choose. Available options are ['rbf', 'imq'] i.e. radial basis functions or inverse multiquadratic kernel. Default: 'imq'. alpha (float): The alpha factor balancing the weigth: Default: 0.5 diff --git a/src/pythae/models/iwae/iwae_config.py b/src/pythae/models/iwae/iwae_config.py index 33dc44d8..749f4690 100644 --- a/src/pythae/models/iwae/iwae_config.py +++ b/src/pythae/models/iwae/iwae_config.py @@ -10,7 +10,7 @@ class IWAEConfig(VAEConfig): Parameters: input_dim (tuple): The input_data dimension. latent_dim (int): The latent space dimension. Default: None. - reconstruction_loss (str): The reconstruction loss to use ['bce', 'mse']. Default: 'mse' + reconstruction_loss (str): The reconstruction loss to use ['bce', 'l1', 'mse']. Default: 'mse' number_samples (int): Number of samples to use on the Monte-Carlo estimation. Default: 10 """ diff --git a/src/pythae/models/miwae/miwae_config.py b/src/pythae/models/miwae/miwae_config.py index 4e78049f..27852a07 100644 --- a/src/pythae/models/miwae/miwae_config.py +++ b/src/pythae/models/miwae/miwae_config.py @@ -10,7 +10,7 @@ class MIWAEConfig(VAEConfig): Parameters: input_dim (tuple): The input_data dimension. latent_dim (int): The latent space dimension. Default: None. - reconstruction_loss (str): The reconstruction loss to use ['bce', 'mse']. Default: 'mse' + reconstruction_loss (str): The reconstruction loss to use ['bce', 'l1', 'mse']. Default: 'mse' number_gradient_estimates (int): Number of (M-)estimates to use for the gradient estimate. Default: 5 number_samples (int): Number of samples to use on the Monte-Carlo estimation. Default: 10 diff --git a/src/pythae/models/piwae/piwae_config.py b/src/pythae/models/piwae/piwae_config.py index 26b56102..7f1b77d9 100644 --- a/src/pythae/models/piwae/piwae_config.py +++ b/src/pythae/models/piwae/piwae_config.py @@ -10,7 +10,7 @@ class PIWAEConfig(VAEConfig): Parameters: input_dim (tuple): The input_data dimension. latent_dim (int): The latent space dimension. Default: None. - reconstruction_loss (str): The reconstruction loss to use ['bce', 'mse']. Default: 'mse' + reconstruction_loss (str): The reconstruction loss to use ['bce', 'l1', 'mse']. Default: 'mse' number_gradient_estimates (int): Number of (M-)estimates to use for the gradient estimate for the encoder. Default: 5 number_samples (int): Number of samples to use on the Monte-Carlo estimation. Default: 10 diff --git a/src/pythae/models/pvae/pvae_config.py b/src/pythae/models/pvae/pvae_config.py index d72432bc..b82855e0 100644 --- a/src/pythae/models/pvae/pvae_config.py +++ b/src/pythae/models/pvae/pvae_config.py @@ -11,7 +11,7 @@ class PoincareVAEConfig(VAEConfig): Parameters: input_dim (tuple): The input_data dimension. latent_dim (int): The latent space dimension. Default: None. - reconstruction_loss (str): The reconstruction loss to use ['bce', 'mse']. Default: 'mse' + reconstruction_loss (str): The reconstruction loss to use ['bce', 'l1', 'mse']. Default: 'mse' prior_distribution (str): The distribution to use as prior ["wrapped_normal", "riemannian_normal"]. Default: "wrapped_normal" posterior_distribution (str): The distribution to use as posterior diff --git a/src/pythae/models/rhvae/rhvae_config.py b/src/pythae/models/rhvae/rhvae_config.py index 1a7ce0e5..d19cc0be 100644 --- a/src/pythae/models/rhvae/rhvae_config.py +++ b/src/pythae/models/rhvae/rhvae_config.py @@ -9,6 +9,7 @@ class RHVAEConfig(VAEConfig): Parameters: latent_dim (int): The latent dimension used for the latent space. Default: 10 + reconstruction_loss (str): The reconstruction loss to use ['bce', 'l1', 'mse']. Default: 'mse' n_lf (int): The number of leapfrog steps to used in the integrator: Default: 3 eps_lf (int): The leapfrog stepsize. Default: 1e-3 beta_zero (int): The tempering factor in the Riemannian Hamiltonian Monte Carlo Sampler. diff --git a/src/pythae/models/svae/svae_config.py b/src/pythae/models/svae/svae_config.py index 612e0a8b..e92b7636 100644 --- a/src/pythae/models/svae/svae_config.py +++ b/src/pythae/models/svae/svae_config.py @@ -11,5 +11,5 @@ class SVAEConfig(VAEConfig): Parameters: input_dim (tuple): The input_data dimension. latent_dim (int): The latent space dimension in which lives the hypersphere. Default: None. - reconstruction_loss (str): The reconstruction loss to use ['bce', 'mse']. Default: 'mse' + reconstruction_loss (str): The reconstruction loss to use ['bce', 'l1', 'mse']. Default: 'mse' """ diff --git a/src/pythae/models/vae/vae_config.py b/src/pythae/models/vae/vae_config.py index 190f97c5..ee904d48 100644 --- a/src/pythae/models/vae/vae_config.py +++ b/src/pythae/models/vae/vae_config.py @@ -11,7 +11,7 @@ class VAEConfig(BaseAEConfig): Parameters: input_dim (tuple): The input_data dimension. latent_dim (int): The latent space dimension. Default: None. - reconstruction_loss (str): The reconstruction loss to use ['bce', 'mse']. Default: 'mse' + reconstruction_loss (str): The reconstruction loss to use ['bce', 'l1', 'mse']. Default: 'mse' """ reconstruction_loss: Literal["bce", "mse"] = "mse" diff --git a/src/pythae/models/vae_iaf/vae_iaf_config.py b/src/pythae/models/vae_iaf/vae_iaf_config.py index 56046741..6a5ab20b 100644 --- a/src/pythae/models/vae_iaf/vae_iaf_config.py +++ b/src/pythae/models/vae_iaf/vae_iaf_config.py @@ -10,7 +10,7 @@ class VAE_IAF_Config(VAEConfig): Parameters: input_dim (int): The input_data dimension latent_dim (int): The latent space dimension. Default: None. - reconstruction_loss (str): The reconstruction loss to use ['bce', 'mse']. Default: 'mse'. + reconstruction_loss (str): The reconstruction loss to use ['bce', 'l1', 'mse']. Default: 'mse'. n_made_blocks (int): The number of :class:`~pythae.models.normalizing_flows.MADE` models to consider in the IAF used in the VAE. Default: 2. n_hidden_in_made (int): The number of hidden layers in the diff --git a/src/pythae/models/vae_lin_nf/vae_lin_nf_config.py b/src/pythae/models/vae_lin_nf/vae_lin_nf_config.py index f8129132..c629ae7c 100644 --- a/src/pythae/models/vae_lin_nf/vae_lin_nf_config.py +++ b/src/pythae/models/vae_lin_nf/vae_lin_nf_config.py @@ -13,7 +13,7 @@ class VAE_LinNF_Config(VAEConfig): Parameters: input_dim (int): The input_data dimension latent_dim (int): The latent space dimension. Default: None. - reconstruction_loss (str): The reconstruction loss to use ['bce', 'mse']. Default: 'mse' + reconstruction_loss (str): The reconstruction loss to use ['bce', 'l1', 'mse']. Default: 'mse' flows (List[str]): A list of strings corresponding to the class of each flow to be applied. Default: ['Plannar', 'Planar']. Flow choices: ['Planar', 'Radial']. """ diff --git a/src/pythae/models/vamp/vamp_config.py b/src/pythae/models/vamp/vamp_config.py index db7374c9..1a720324 100644 --- a/src/pythae/models/vamp/vamp_config.py +++ b/src/pythae/models/vamp/vamp_config.py @@ -10,7 +10,7 @@ class VAMPConfig(VAEConfig): Parameters: input_dim (tuple): The input_data dimension. latent_dim (int): The latent space dimension. Default: None. - reconstruction_loss (str): The reconstruction loss to use ['bce', 'mse']. Default: 'mse' + reconstruction_loss (str): The reconstruction loss to use ['bce', 'l1', 'mse']. Default: 'mse' number_components (int): The number of components to use in the VAMP prior. Default: 50 linear_scheduling_steps (int): The number of warmup steps to perform using a linear scheduling. Default: 0 From 6ad1c0af8732d384e92c3f2e143732b6b9f1e295 Mon Sep 17 00:00:00 2001 From: "soumick.chatterjee" Date: Thu, 6 Apr 2023 14:40:24 +0200 Subject: [PATCH 06/15] custom loss function added to factor VAE --- src/pythae/models/factor_vae/factor_vae_model.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/src/pythae/models/factor_vae/factor_vae_model.py b/src/pythae/models/factor_vae/factor_vae_model.py index 97c691eb..805df801 100644 --- a/src/pythae/models/factor_vae/factor_vae_model.py +++ b/src/pythae/models/factor_vae/factor_vae_model.py @@ -1,5 +1,5 @@ import os -from typing import Optional +from typing import Optional, Union, Callable import torch import torch.nn.functional as F @@ -31,6 +31,11 @@ class FactorVAE(VAE): architectures if desired. If None is provided, a simple Multi Layer Preception (https://en.wikipedia.org/wiki/Multilayer_perceptron) is used. Default: None. + custom_recon_loss_func: (torch.nn.Module or Callable): A custom loss function for calculation the reconstruction loss. + This is only used when the `reconstruction_loss` parameter in `model_config` is set to `custom`. This can be either + an instance of `torch.nn.Module` or a callable function. In either case, the function must take the following arguments: + - `recon_x`: The reconstructed data + - `x`: The original data. Default: None. .. note:: For high dimensional data we advice you to provide you own network architectures. With the @@ -42,6 +47,7 @@ def __init__( model_config: FactorVAEConfig, encoder: Optional[BaseEncoder] = None, decoder: Optional[BaseDecoder] = None, + custom_recon_loss_func: Optional[Union[torch.nn.Module, Callable]] = None, ): VAE.__init__(self, model_config=model_config, encoder=encoder, decoder=decoder) @@ -50,6 +56,7 @@ def __init__( self.model_name = "FactorVAE" self.gamma = model_config.gamma + self.custom_recon_loss_func = custom_recon_loss_func def set_discriminator(self, discriminator: BaseDiscriminator) -> None: r"""This method is called to set the discriminator network @@ -159,6 +166,10 @@ def loss_function(self, recon_x, x, mu, log_var, z, z_bis_permuted): reduction="none", ).sum(dim=-1) + elif self.model_config.reconstruction_loss == "custom": + + recon_loss = self.custom_recon_loss_func(recon_x, x) + KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp(), dim=-1) latent_adversarial_score = self.discriminator(z) @@ -278,4 +289,4 @@ def _permute_dims(self, z): perms = torch.randperm(z.shape[0]).to(z.device) permuted[:, i] = z[perms, i] - return permuted + return permuted \ No newline at end of file From 518f64abe096dd71fabd761218e007169c09e9f9 Mon Sep 17 00:00:00 2001 From: Soumick Chatterjee Date: Tue, 23 May 2023 12:30:28 +0200 Subject: [PATCH 07/15] boh --- src/pythae/models/__init__.py | 0 .../models/factor_vae/factor_vae_model.py | 25 ++++++++++++++++--- 2 files changed, 21 insertions(+), 4 deletions(-) mode change 100755 => 100644 src/pythae/models/__init__.py diff --git a/src/pythae/models/__init__.py b/src/pythae/models/__init__.py old mode 100755 new mode 100644 diff --git a/src/pythae/models/factor_vae/factor_vae_model.py b/src/pythae/models/factor_vae/factor_vae_model.py index 805df801..c9aa8e01 100644 --- a/src/pythae/models/factor_vae/factor_vae_model.py +++ b/src/pythae/models/factor_vae/factor_vae_model.py @@ -99,6 +99,13 @@ def forward(self, inputs: BaseDataset, **kwargs) -> ModelOutput: # first batch x = inputs["data"][idx_1] + if self.model_config.reconstruction_loss == "custom_masked": + if "mask" not in inputs.keys(): + raise ValueError( + "No mask not present in the input for `custom_masked` reconstruction loss" + ) + mask = inputs["mask"][idx_1] + encoder_output = self.encoder(x) mu, log_var = encoder_output.embedding, encoder_output.log_covariance @@ -119,9 +126,15 @@ def forward(self, inputs: BaseDataset, **kwargs) -> ModelOutput: z_bis_permuted = self._permute_dims(z_bis).detach() - recon_loss, autoencoder_loss, discriminator_loss = self.loss_function( - recon_x, x, mu, log_var, z, z_bis_permuted - ) + if not self.model_config.reconstruction_loss == "custom_masked": + recon_loss, autoencoder_loss, discriminator_loss = self.loss_function( + recon_x, x, mu, log_var, z, z_bis_permuted + ) + else: + recon_loss, autoencoder_loss, discriminator_loss = self.loss_function( + recon_x, x, mu, log_var, z, z_bis_permuted, mask + ) + loss = autoencoder_loss + discriminator_loss @@ -138,7 +151,7 @@ def forward(self, inputs: BaseDataset, **kwargs) -> ModelOutput: return output - def loss_function(self, recon_x, x, mu, log_var, z, z_bis_permuted): + def loss_function(self, recon_x, x, mu, log_var, z, z_bis_permuted, mask=None): N = z.shape[0] # batch size @@ -170,6 +183,10 @@ def loss_function(self, recon_x, x, mu, log_var, z, z_bis_permuted): recon_loss = self.custom_recon_loss_func(recon_x, x) + elif self.model_config.reconstruction_loss == "custom_masked": + + recon_loss = self.custom_recon_loss_func(recon_x, x, mask) + KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp(), dim=-1) latent_adversarial_score = self.discriminator(z) From e01360dcb08ac8e45c655a40892688524a9796f2 Mon Sep 17 00:00:00 2001 From: "soumick.chatterjee" Date: Tue, 23 May 2023 12:31:41 +0200 Subject: [PATCH 08/15] boh --- .coveragerc | 0 .github/ISSUE_TEMPLATE/bug_report.md | 0 .github/ISSUE_TEMPLATE/feature_request.md | 0 .github/workflows/coverage.yml | 0 .github/workflows/tests_bench.yml | 0 .gitignore | 0 .readthedocs.yml | 0 CITATION.cff | 0 CONTRIBUTING.md | 0 LICENSE | 0 MANIFEST.in | 0 README.md | 0 codecov.yml | 0 docs/Makefile | 0 docs/README.md | 0 docs/make.bat | 0 docs/requirements.txt | 0 docs/source/_static/basic.css | 0 docs/source/_templates/_variables.scss | 0 docs/source/_templates/layout.html | 0 docs/source/conf.py | 0 docs/source/imgs/CNNs.jpeg | Bin docs/source/imgs/Case_study_1.jpg | Bin docs/source/imgs/DA_diagram(1).pdf | Bin docs/source/imgs/DA_diagram.png | Bin docs/source/imgs/baseline_results.png | Bin docs/source/imgs/logo_pyraug_2.jpeg | Bin docs/source/imgs/nine_digits-rot.pdf | Bin docs/source/imgs/nine_digits-rot.png | Bin docs/source/imgs/nine_digits.pdf | Bin docs/source/imgs/nine_digits.png | Bin docs/source/imgs/optimized_results.png | Bin docs/source/imgs/pyraug_diagram.jpg | Bin docs/source/imgs/pyraug_diagram_simplified.jpg | Bin docs/source/index.rst | 0 docs/source/models/autoencoders/aae.rst | 0 docs/source/models/autoencoders/ae.rst | 0 docs/source/models/autoencoders/auto_model.rst | 0 docs/source/models/autoencoders/baseAE.rst | 0 docs/source/models/autoencoders/betatcvae.rst | 0 docs/source/models/autoencoders/betavae.rst | 0 docs/source/models/autoencoders/ciwae.rst | 0 .../models/autoencoders/disentangled_betavae.rst | 0 docs/source/models/autoencoders/factorvae.rst | 0 docs/source/models/autoencoders/hvae.rst | 0 docs/source/models/autoencoders/infovae.rst | 0 docs/source/models/autoencoders/iwae.rst | 0 docs/source/models/autoencoders/miwae.rst | 0 docs/source/models/autoencoders/models.rst | 0 docs/source/models/autoencoders/msssimvae.rst | 0 docs/source/models/autoencoders/piwae.rst | 0 docs/source/models/autoencoders/pvae.rst | 0 docs/source/models/autoencoders/rae_gp.rst | 0 docs/source/models/autoencoders/rae_l2.rst | 0 docs/source/models/autoencoders/rhvae.rst | 0 docs/source/models/autoencoders/svae.rst | 0 docs/source/models/autoencoders/vae.rst | 0 docs/source/models/autoencoders/vae_iaf.rst | 0 docs/source/models/autoencoders/vae_lin_nf.rst | 0 docs/source/models/autoencoders/vaegan.rst | 0 docs/source/models/autoencoders/vamp.rst | 0 docs/source/models/autoencoders/vqvae.rst | 0 docs/source/models/autoencoders/wae.rst | 0 docs/source/models/nn/celeba/convnets.rst | 0 .../nn/celeba/pythae_benchmarks_nn_celeba.rst | 0 docs/source/models/nn/celeba/resnets.rst | 0 docs/source/models/nn/cifar/convnets.rst | 0 .../models/nn/cifar/pythae_benchmarks_nn_cifar.rst | 0 docs/source/models/nn/cifar/resnets.rst | 0 docs/source/models/nn/mnist/convnets.rst | 0 .../models/nn/mnist/pythae_benchmarks_nn_mnist.rst | 0 docs/source/models/nn/mnist/resnets.rst | 0 docs/source/models/nn/nn.rst | 0 docs/source/models/nn/pythae_base_nn.rst | 0 docs/source/models/nn/pythae_benchmarks_nn.rst | 0 docs/source/models/normalizing_flows/basenf.rst | 0 docs/source/models/normalizing_flows/iaf.rst | 0 docs/source/models/normalizing_flows/made.rst | 0 docs/source/models/normalizing_flows/maf.rst | 0 .../models/normalizing_flows/normalizing_flows.rst | 0 docs/source/models/normalizing_flows/pixelcnn.rst | 0 .../source/models/normalizing_flows/planar_flow.rst | 0 .../source/models/normalizing_flows/radial_flow.rst | 0 docs/source/models/pythae.models.rst | 0 docs/source/pipelines/generation.rst | 0 docs/source/pipelines/pythae.pipelines.rst | 0 docs/source/pipelines/training.rst | 0 docs/source/references.bib | 0 docs/source/samplers/basesampler.rst | 0 docs/source/samplers/gmm_sampler.rst | 0 docs/source/samplers/iaf_sampler.rst | 0 docs/source/samplers/maf_sampler.rst | 0 docs/source/samplers/normal_sampler.rst | 0 docs/source/samplers/pixelcnn_sampler.rst | 0 docs/source/samplers/poincare_disk_sampler.rst | 0 docs/source/samplers/pythae.samplers.rst | 0 docs/source/samplers/rhvae_sampler.rst | 0 docs/source/samplers/twostage_sampler.rst | 0 docs/source/samplers/unit_sphere_unif_sampler.rst | 0 docs/source/samplers/vamp_sampler.rst | 0 docs/source/trainers/pythae.trainer.rst | 0 .../pythae.trainers.adversarial_trainer.rst | 0 .../trainers/pythae.trainers.base_trainer.rst | 0 ...ainers.coupled_optimizer_adversarial_trainer.rst | 0 .../trainers/pythae.trainers.coupled_trainer.rst | 0 docs/source/trainers/pythae.training_callbacks.rst | 0 examples/notebooks/README.md | 0 .../notebooks/comet_experiment_monitoring.ipynb | 0 examples/notebooks/custom_dataset.ipynb | 0 examples/notebooks/hf_hub_models_sharing.ipynb | 0 .../notebooks/making_your_own_autoencoder.ipynb | 0 .../notebooks/mlflow_experiment_monitoring.ipynb | 0 .../models_training/adversarial_ae_training.ipynb | 0 .../notebooks/models_training/ae_training.ipynb | 0 .../models_training/beta_tc_vae_training.ipynb | 0 .../models_training/beta_vae_training.ipynb | 0 .../notebooks/models_training/ciwae_training.ipynb | 0 .../disentangled_beta_vae_training.ipynb | 0 .../models_training/factor_vae_training.ipynb | 0 .../notebooks/models_training/hvae_training.ipynb | 0 .../models_training/info_vae_training.ipynb | 0 .../notebooks/models_training/iwae_training.ipynb | 0 .../notebooks/models_training/miwae_training.ipynb | 0 .../models_training/ms_ssim_vae_training.ipynb | 0 .../normalizing_flows_training.ipynb | 0 .../notebooks/models_training/piwae_training.ipynb | 0 .../notebooks/models_training/pvae_training.ipynb | 0 .../notebooks/models_training/rae_gp_training.ipynb | 0 .../notebooks/models_training/rae_l2_training.ipynb | 0 .../notebooks/models_training/rhvae_training.ipynb | 0 .../notebooks/models_training/svae_training.ipynb | 0 .../models_training/vae_iaf_training.ipynb | 0 .../models_training/vae_lin_nf_training.ipynb | 0 examples/notebooks/models_training/vae_lstm.ipynb | 0 .../notebooks/models_training/vae_training.ipynb | 0 .../notebooks/models_training/vaegan_training.ipynb | 0 .../notebooks/models_training/vamp_training.ipynb | 0 .../notebooks/models_training/vqvae_training.ipynb | 0 .../notebooks/models_training/wae_training.ipynb | 0 examples/notebooks/overview_notebook.ipynb | 0 examples/notebooks/requirements.txt | 0 .../notebooks/wandb_experiment_monitoring.ipynb | 0 examples/scripts/README.md | 0 .../configs/binary_mnist/base_training_config.json | 0 .../scripts/configs/binary_mnist/hvae_config.json | 0 .../scripts/configs/binary_mnist/iwae_config.json | 0 .../scripts/configs/binary_mnist/rhvae_config.json | 0 .../scripts/configs/binary_mnist/svae_config.json | 0 .../scripts/configs/binary_mnist/vae_config.json | 0 .../configs/binary_mnist/vae_iaf_config.json | 0 .../configs/binary_mnist/vae_lin_nf_config.json | 0 .../scripts/configs/binary_mnist/vamp_config.json | 0 examples/scripts/configs/celeba/aae_config.json | 0 examples/scripts/configs/celeba/ae_config.json | 0 .../configs/celeba/base_training_config.json | 0 .../scripts/configs/celeba/beta_tc_vae_config.json | 0 .../scripts/configs/celeba/beta_vae_config.json | 0 .../celeba/disentangled_beta_vae_config.json | 0 .../scripts/configs/celeba/factor_vae_config.json | 0 examples/scripts/configs/celeba/hvae_config.json | 0 .../scripts/configs/celeba/info_vae_config.json | 0 .../scripts/configs/celeba/msssim_vae_config.json | 0 examples/scripts/configs/celeba/rae_gp_config.json | 0 examples/scripts/configs/celeba/rae_l2_config.json | 0 examples/scripts/configs/celeba/rhvae_config.json | 0 examples/scripts/configs/celeba/svae_config.json | 0 examples/scripts/configs/celeba/vae_config.json | 0 examples/scripts/configs/celeba/vae_iaf_config.json | 0 .../scripts/configs/celeba/vae_lin_nf_config.json | 0 examples/scripts/configs/celeba/vaegan_config.json | 0 examples/scripts/configs/celeba/vamp_config.json | 0 examples/scripts/configs/celeba/vqvae_config.json | 0 examples/scripts/configs/celeba/wae_config.json | 0 examples/scripts/configs/cifar10/ae_config.json | 0 .../configs/cifar10/base_training_config.json | 0 .../scripts/configs/cifar10/beta_vae_config.json | 0 examples/scripts/configs/cifar10/hvae_config.json | 0 .../scripts/configs/cifar10/info_vae_config.json | 0 examples/scripts/configs/cifar10/rae_gp_config.json | 0 examples/scripts/configs/cifar10/rae_l2_config.json | 0 examples/scripts/configs/cifar10/rhvae_config.json | 0 examples/scripts/configs/cifar10/vae_config.json | 0 examples/scripts/configs/cifar10/vamp_config.json | 0 examples/scripts/configs/cifar10/vqvae_config.json | 0 examples/scripts/configs/cifar10/wae_config.json | 0 .../dsprites/beta_tc_vae/base_training_config.json | 0 .../dsprites/beta_tc_vae/beta_tc_vae_config.json | 0 .../dsprites/factorvae/base_training_config.json | 0 .../dsprites/factorvae/factorvae_config.json | 0 examples/scripts/configs/mnist/aae_config.json | 0 examples/scripts/configs/mnist/ae_config.json | 0 .../scripts/configs/mnist/base_training_config.json | 0 .../scripts/configs/mnist/beta_tc_vae_config.json | 0 examples/scripts/configs/mnist/beta_vae_config.json | 0 .../configs/mnist/disentangled_beta_vae_config.json | 0 .../scripts/configs/mnist/factor_vae_config.json | 0 examples/scripts/configs/mnist/hvae_config.json | 0 examples/scripts/configs/mnist/info_vae_config.json | 0 examples/scripts/configs/mnist/iwae_config.json | 0 .../scripts/configs/mnist/msssim_vae_config.json | 0 examples/scripts/configs/mnist/rae_gp_config.json | 0 examples/scripts/configs/mnist/rae_l2_config.json | 0 examples/scripts/configs/mnist/rhvae_config.json | 0 examples/scripts/configs/mnist/svae_config.json | 0 examples/scripts/configs/mnist/vae_config.json | 0 examples/scripts/configs/mnist/vae_iaf_config.json | 0 .../scripts/configs/mnist/vae_lin_nf_config.json | 0 examples/scripts/configs/mnist/vaegan_config.json | 0 examples/scripts/configs/mnist/vamp_config.json | 0 examples/scripts/configs/mnist/vqvae_config.json | 0 examples/scripts/configs/mnist/wae_config.json | 0 examples/scripts/custom_nn.py | 0 examples/scripts/distributed_training_ffhq.py | 0 examples/scripts/distributed_training_imagenet.py | 0 examples/scripts/distributed_training_mnist.py | 0 examples/scripts/reproducibility/README.md | 0 examples/scripts/reproducibility/aae.py | 0 examples/scripts/reproducibility/betatcvae.py | 0 examples/scripts/reproducibility/ciwae.py | 0 examples/scripts/reproducibility/hvae.py | 0 examples/scripts/reproducibility/iwae.py | 0 examples/scripts/reproducibility/miwae.py | 0 examples/scripts/reproducibility/piwae.py | 0 examples/scripts/reproducibility/pvae.py | 0 examples/scripts/reproducibility/rae_gp.py | 0 examples/scripts/reproducibility/rae_l2.py | 0 examples/scripts/reproducibility/svae.py | 0 examples/scripts/reproducibility/vae.py | 0 examples/scripts/reproducibility/vamp.py | 0 examples/scripts/reproducibility/wae.py | 0 examples/scripts/training.py | 0 examples/showcases/aae_normal_sampling_celeba.png | Bin examples/showcases/aae_normal_sampling_mnist.png | Bin examples/showcases/aae_reconstruction_celeba.png | Bin examples/showcases/aae_reconstruction_mnist.png | Bin examples/showcases/ae_gmm_sampling_celeba.png | Bin examples/showcases/ae_gmm_sampling_mnist.png | Bin examples/showcases/ae_normal_sampling_celeba.png | Bin examples/showcases/ae_normal_sampling_mnist.png | Bin examples/showcases/ae_reconstruction_celeba.png | Bin examples/showcases/ae_reconstruction_mnist.png | Bin .../beta_tc_vae_normal_sampling_celeba.png | Bin .../showcases/beta_tc_vae_normal_sampling_mnist.png | Bin .../showcases/beta_tc_vae_reconstruction_celeba.png | Bin .../showcases/beta_tc_vae_reconstruction_mnist.png | Bin .../showcases/beta_vae_normal_sampling_celeba.png | Bin .../showcases/beta_vae_normal_sampling_mnist.png | Bin .../showcases/beta_vae_reconstruction_celeba.png | Bin .../showcases/beta_vae_reconstruction_mnist.png | Bin ...disentangled_beta_vae_normal_sampling_celeba.png | Bin .../disentangled_beta_vae_normal_sampling_mnist.png | Bin .../disentangled_beta_vae_reconstruction_celeba.png | Bin .../disentangled_beta_vae_reconstruction_mnist.png | Bin examples/showcases/eval_reconstruction_celeba.png | Bin examples/showcases/eval_reconstruction_mnist.png | Bin .../showcases/factor_vae_normal_sampling_celeba.png | Bin .../showcases/factor_vae_normal_sampling_mnist.png | Bin .../showcases/factor_vae_reconstruction_celeba.png | Bin .../showcases/factor_vae_reconstruction_mnist.png | Bin examples/showcases/hvae_gmm_sampling_mnist.png | Bin examples/showcases/hvae_normal_sampling_celeba.png | Bin examples/showcases/hvae_normal_sampling_mnist.png | Bin examples/showcases/hvae_reconstruction_celeba.png | Bin examples/showcases/hvae_reconstruction_mnist.png | Bin examples/showcases/hvvae_gmm_sampling_mnist.png | Bin examples/showcases/hvvae_normal_sampling_mnist.png | Bin .../showcases/infovae_normal_sampling_celeba.png | Bin .../showcases/infovae_normal_sampling_mnist.png | Bin .../showcases/infovae_reconstruction_celeba.png | Bin examples/showcases/infovae_reconstruction_mnist.png | Bin examples/showcases/iwae_gmm_sampling_mnist.png | Bin examples/showcases/iwae_normal_sampling_celeba.png | Bin examples/showcases/iwae_normal_sampling_mnist.png | Bin examples/showcases/iwae_reconstruction_celeba.png | Bin examples/showcases/iwae_reconstruction_mnist.png | Bin .../showcases/msssim_vae_normal_sampling_celeba.png | Bin .../showcases/msssim_vae_normal_sampling_mnist.png | Bin .../showcases/msssim_vae_reconstruction_celeba.png | Bin .../showcases/msssim_vae_reconstruction_mnist.png | Bin examples/showcases/rae_gp_gmm_sampling_celeba.png | Bin examples/showcases/rae_gp_gmm_sampling_mnist.png | Bin .../showcases/rae_gp_normal_sampling_celeba.png | Bin examples/showcases/rae_gp_normal_sampling_mnist.png | Bin examples/showcases/rae_gp_reconstruction_celeba.png | Bin examples/showcases/rae_gp_reconstruction_mnist.png | Bin examples/showcases/rae_l2_gmm_sampling_celeba.png | Bin examples/showcases/rae_l2_gmm_sampling_mnist.png | Bin .../showcases/rae_l2_normal_sampling_celeba.png | Bin examples/showcases/rae_l2_normal_sampling_mnist.png | Bin examples/showcases/rae_l2_reconstruction_celeba.png | Bin examples/showcases/rae_l2_reconstruction_mnist.png | Bin examples/showcases/rhvae_reconstruction_celeba.png | Bin examples/showcases/rhvae_reconstruction_mnist.png | Bin examples/showcases/rhvae_rhvae_sampling_celeba.png | Bin examples/showcases/rhvae_rhvae_sampling_mnist.png | Bin .../svae_hypersphere_uniform_sampling_celeba.png | Bin .../svae_hypersphere_uniform_sampling_mnist.png | Bin examples/showcases/svae_reconstruction_celeba.png | Bin examples/showcases/svae_reconstruction_mnist.png | Bin examples/showcases/vae_gmm_sampling_celeba.png | Bin examples/showcases/vae_gmm_sampling_mnist.png | Bin .../showcases/vae_iaf_normal_sampling_celeba.png | Bin .../showcases/vae_iaf_normal_sampling_mnist.png | Bin .../showcases/vae_iaf_reconstruction_celeba.png | Bin examples/showcases/vae_iaf_reconstruction_mnist.png | Bin .../showcases/vae_lin_nf_normal_sampling_celeba.png | Bin .../showcases/vae_lin_nf_normal_sampling_mnist.png | Bin .../showcases/vae_lin_nf_reconstruction_celeba.png | Bin .../showcases/vae_lin_nf_reconstruction_mnist.png | Bin examples/showcases/vae_maf_sampling_celeba.png | Bin examples/showcases/vae_maf_sampling_mnist.png | Bin examples/showcases/vae_normal_sampling_celeba.png | Bin examples/showcases/vae_normal_sampling_mnist.png | Bin examples/showcases/vae_reconstruction_celeba.png | Bin examples/showcases/vae_reconstruction_mnist.png | Bin .../showcases/vae_second_stage_sampling_celeba.png | Bin .../showcases/vae_second_stage_sampling_mnist.png | Bin .../showcases/vaegan_normal_sampling_celeba.png | Bin examples/showcases/vaegan_normal_sampling_mnist.png | Bin examples/showcases/vaegan_reconstruction_celeba.png | Bin examples/showcases/vaegan_reconstruction_mnist.png | Bin examples/showcases/vamp_reconstruction_celeba.png | Bin examples/showcases/vamp_reconstruction_mnist.png | Bin examples/showcases/vamp_vamp_sampling_celeba.png | Bin examples/showcases/vamp_vamp_sampling_mnist.png | Bin examples/showcases/vqvae_maf_sampling_celeba.png | Bin examples/showcases/vqvae_maf_sampling_mnist.png | Bin .../showcases/vqvae_pixelcnn_sampling_mnist.png | Bin examples/showcases/vqvae_reconstruction_celeba.png | Bin examples/showcases/vqvae_reconstruction_mnist.png | Bin examples/showcases/wae_gmm_sampling_celeba.png | Bin examples/showcases/wae_gmm_sampling_mnist.png | Bin examples/showcases/wae_normal_sampling_celeba.png | Bin examples/showcases/wae_normal_sampling_mnist.png | Bin examples/showcases/wae_reconstruction_celeba.png | Bin examples/showcases/wae_reconstruction_mnist.png | Bin pyproject.toml | 0 requirements.txt | 0 setup.cfg | 0 setup.py | 0 src/pythae/__init__.py | 0 src/pythae/config.py | 0 src/pythae/customexception.py | 0 src/pythae/data/__init__.py | 0 src/pythae/data/datasets.py | 0 src/pythae/data/preprocessors.py | 0 src/pythae/models/__init__.py | 0 src/pythae/models/adversarial_ae/__init__.py | 0 .../models/adversarial_ae/adversarial_ae_config.py | 0 .../models/adversarial_ae/adversarial_ae_model.py | 0 src/pythae/models/ae/__init__.py | 0 src/pythae/models/ae/ae_config.py | 0 src/pythae/models/ae/ae_model.py | 0 src/pythae/models/auto_model/__init__.py | 0 src/pythae/models/auto_model/auto_config.py | 0 src/pythae/models/auto_model/auto_model.py | 0 src/pythae/models/base/__init__.py | 0 src/pythae/models/base/base_config.py | 0 src/pythae/models/base/base_model.py | 0 src/pythae/models/base/base_utils.py | 0 src/pythae/models/beta_tc_vae/__init__.py | 0 src/pythae/models/beta_tc_vae/beta_tc_vae_config.py | 0 src/pythae/models/beta_tc_vae/beta_tc_vae_model.py | 0 src/pythae/models/beta_vae/__init__.py | 0 src/pythae/models/beta_vae/beta_vae_config.py | 0 src/pythae/models/beta_vae/beta_vae_model.py | 0 src/pythae/models/ciwae/__init__.py | 0 src/pythae/models/ciwae/ciwae_config.py | 0 src/pythae/models/ciwae/ciwae_model.py | 0 src/pythae/models/disentangled_beta_vae/__init__.py | 0 .../disentangled_beta_vae_config.py | 0 .../disentangled_beta_vae_model.py | 0 src/pythae/models/factor_vae/__init__.py | 0 src/pythae/models/factor_vae/factor_vae_config.py | 0 src/pythae/models/factor_vae/factor_vae_model.py | 0 src/pythae/models/factor_vae/factor_vae_utils.py | 0 src/pythae/models/hvae/__init__.py | 0 src/pythae/models/hvae/hvae_config.py | 0 src/pythae/models/hvae/hvae_model.py | 0 src/pythae/models/info_vae/__init__.py | 0 src/pythae/models/info_vae/info_vae_config.py | 0 src/pythae/models/info_vae/info_vae_model.py | 0 src/pythae/models/iwae/__init__.py | 0 src/pythae/models/iwae/iwae_config.py | 0 src/pythae/models/iwae/iwae_model.py | 0 src/pythae/models/miwae/__init__.py | 0 src/pythae/models/miwae/miwae_config.py | 0 src/pythae/models/miwae/miwae_model.py | 0 src/pythae/models/msssim_vae/__init__.py | 0 src/pythae/models/msssim_vae/msssim_vae_config.py | 0 src/pythae/models/msssim_vae/msssim_vae_model.py | 0 src/pythae/models/msssim_vae/msssim_vae_utils.py | 0 src/pythae/models/nn/__init__.py | 0 src/pythae/models/nn/base_architectures.py | 0 src/pythae/models/nn/benchmarks/__init__.py | 0 src/pythae/models/nn/benchmarks/celeba/__init__.py | 0 src/pythae/models/nn/benchmarks/celeba/convnets.py | 0 src/pythae/models/nn/benchmarks/celeba/resnets.py | 0 src/pythae/models/nn/benchmarks/cifar/__init__.py | 0 src/pythae/models/nn/benchmarks/cifar/convnets.py | 0 src/pythae/models/nn/benchmarks/cifar/resnets.py | 0 src/pythae/models/nn/benchmarks/mnist/__init__.py | 0 src/pythae/models/nn/benchmarks/mnist/convnets.py | 0 src/pythae/models/nn/benchmarks/mnist/resnets.py | 0 src/pythae/models/nn/benchmarks/utils.py | 0 src/pythae/models/nn/default_architectures.py | 0 src/pythae/models/normalizing_flows/__init__.py | 0 .../models/normalizing_flows/base/__init__.py | 0 .../models/normalizing_flows/base/base_nf_config.py | 0 .../models/normalizing_flows/base/base_nf_model.py | 0 src/pythae/models/normalizing_flows/iaf/__init__.py | 0 .../models/normalizing_flows/iaf/iaf_config.py | 0 .../models/normalizing_flows/iaf/iaf_model.py | 0 src/pythae/models/normalizing_flows/layers.py | 0 .../models/normalizing_flows/made/__init__.py | 0 .../models/normalizing_flows/made/made_config.py | 0 .../models/normalizing_flows/made/made_model.py | 0 src/pythae/models/normalizing_flows/maf/__init__.py | 0 .../models/normalizing_flows/maf/maf_config.py | 0 .../models/normalizing_flows/maf/maf_model.py | 0 .../models/normalizing_flows/pixelcnn/__init__.py | 0 .../normalizing_flows/pixelcnn/pixelcnn_config.py | 0 .../normalizing_flows/pixelcnn/pixelcnn_model.py | 0 .../models/normalizing_flows/pixelcnn/utils.py | 0 .../normalizing_flows/planar_flow/__init__.py | 0 .../planar_flow/planar_flow_config.py | 0 .../planar_flow/planar_flow_model.py | 0 .../normalizing_flows/radial_flow/__init__.py | 0 .../radial_flow/radial_flow_config.py | 0 .../radial_flow/radial_flow_model.py | 0 src/pythae/models/piwae/__init__.py | 0 src/pythae/models/piwae/piwae_config.py | 0 src/pythae/models/piwae/piwae_model.py | 0 src/pythae/models/pvae/__init__.py | 0 src/pythae/models/pvae/pvae_config.py | 0 src/pythae/models/pvae/pvae_model.py | 0 src/pythae/models/pvae/pvae_utils.py | 0 src/pythae/models/rae_gp/__init__.py | 0 src/pythae/models/rae_gp/rae_gp_config.py | 0 src/pythae/models/rae_gp/rae_gp_model.py | 0 src/pythae/models/rae_l2/__init__.py | 0 src/pythae/models/rae_l2/rae_l2_config.py | 0 src/pythae/models/rae_l2/rae_l2_model.py | 0 src/pythae/models/rhvae/__init__.py | 0 src/pythae/models/rhvae/rhvae_config.py | 0 src/pythae/models/rhvae/rhvae_model.py | 0 src/pythae/models/rhvae/rhvae_utils.py | 0 src/pythae/models/svae/__init__.py | 0 src/pythae/models/svae/svae_config.py | 0 src/pythae/models/svae/svae_model.py | 0 src/pythae/models/svae/svae_utils.py | 0 src/pythae/models/vae/__init__.py | 0 src/pythae/models/vae/vae_config.py | 0 src/pythae/models/vae/vae_model.py | 0 src/pythae/models/vae_gan/__init__.py | 0 src/pythae/models/vae_gan/vae_gan_config.py | 0 src/pythae/models/vae_gan/vae_gan_model.py | 0 src/pythae/models/vae_iaf/__init__.py | 0 src/pythae/models/vae_iaf/vae_iaf_config.py | 0 src/pythae/models/vae_iaf/vae_iaf_model.py | 0 src/pythae/models/vae_lin_nf/__init__.py | 0 src/pythae/models/vae_lin_nf/vae_lin_nf_config.py | 0 src/pythae/models/vae_lin_nf/vae_lin_nf_model.py | 0 src/pythae/models/vamp/__init__.py | 0 src/pythae/models/vamp/vamp_config.py | 0 src/pythae/models/vamp/vamp_model.py | 0 src/pythae/models/vq_vae/__init__.py | 0 src/pythae/models/vq_vae/vq_vae_config.py | 0 src/pythae/models/vq_vae/vq_vae_model.py | 0 src/pythae/models/vq_vae/vq_vae_utils.py | 0 src/pythae/models/wae_mmd/__init__.py | 0 src/pythae/models/wae_mmd/wae_mmd_config.py | 0 src/pythae/models/wae_mmd/wae_mmd_model.py | 0 src/pythae/pipelines/__init__.py | 0 src/pythae/pipelines/base_pipeline.py | 0 src/pythae/pipelines/generation.py | 0 src/pythae/pipelines/pipeline_utils.py | 0 src/pythae/pipelines/training.py | 0 src/pythae/py.typed | 0 src/pythae/samplers/__init__.py | 0 src/pythae/samplers/base/__init__.py | 0 src/pythae/samplers/base/base_sampler.py | 0 src/pythae/samplers/base/base_sampler_config.py | 0 src/pythae/samplers/gaussian_mixture/__init__.py | 0 .../gaussian_mixture/gaussian_mixture_config.py | 0 .../gaussian_mixture/gaussian_mixture_sampler.py | 0 .../hypersphere_uniform_sampler/__init__.py | 0 .../hypersphere_uniform_config.py | 0 .../hypersphere_uniform_sampler.py | 0 src/pythae/samplers/iaf_sampler/__init__.py | 0 src/pythae/samplers/iaf_sampler/iaf_sampler.py | 0 .../samplers/iaf_sampler/iaf_sampler_config.py | 0 src/pythae/samplers/maf_sampler/__init__.py | 0 src/pythae/samplers/maf_sampler/maf_sampler.py | 0 .../samplers/maf_sampler/maf_sampler_config.py | 0 src/pythae/samplers/manifold_sampler/__init__.py | 0 .../samplers/manifold_sampler/rhvae_sampler.py | 0 .../manifold_sampler/rhvae_sampler_config.py | 0 src/pythae/samplers/normal_sampling/__init__.py | 0 .../samplers/normal_sampling/normal_config.py | 0 .../samplers/normal_sampling/normal_sampler.py | 0 src/pythae/samplers/pixelcnn_sampler/__init__.py | 0 .../samplers/pixelcnn_sampler/pixelcnn_sampler.py | 0 .../pixelcnn_sampler/pixelcnn_sampler_config.py | 0 src/pythae/samplers/pvae_sampler/__init__.py | 0 src/pythae/samplers/pvae_sampler/pvae_sampler.py | 0 .../samplers/pvae_sampler/pvae_sampler_config.py | 0 .../samplers/two_stage_vae_sampler/__init__.py | 0 .../two_stage_vae_sampler/two_stage_sampler.py | 0 .../two_stage_sampler_config.py | 0 src/pythae/samplers/vamp_sampler/__init__.py | 0 src/pythae/samplers/vamp_sampler/vamp_sampler.py | 0 .../samplers/vamp_sampler/vamp_sampler_config.py | 0 src/pythae/trainers/__init__.py | 0 src/pythae/trainers/adversarial_trainer/__init__.py | 0 .../adversarial_trainer/adversarial_trainer.py | 0 .../adversarial_trainer_config.py | 0 src/pythae/trainers/base_trainer/__init__.py | 0 src/pythae/trainers/base_trainer/base_trainer.py | 0 .../trainers/base_trainer/base_training_config.py | 0 .../__init__.py | 0 .../coupled_optimizer_adversarial_trainer.py | 0 .../coupled_optimizer_adversarial_trainer_config.py | 0 .../trainers/coupled_optimizer_trainer/__init__.py | 0 .../coupled_optimizer_trainer.py | 0 .../coupled_optimizer_trainer_config.py | 0 src/pythae/trainers/trainer_utils.py | 0 src/pythae/trainers/training_callbacks.py | 0 tests/README.md | 0 tests/__init__.py | 0 tests/conftest.py | 0 .../data/baseAE/configs/corrupted_model_config.json | 0 tests/data/baseAE/configs/generation_config00.json | 0 tests/data/baseAE/configs/model_config00.json | 0 tests/data/baseAE/configs/not_json_file.md | 0 tests/data/baseAE/configs/training_config00.json | 0 tests/data/corrupted_config/model_config.json | 0 tests/data/custom_architectures.py | 0 tests/data/loading/dummy_data_folder/example0.bmp | Bin tests/data/loading/dummy_data_folder/example0.jpeg | Bin tests/data/loading/dummy_data_folder/example0.jpg | Bin tests/data/loading/dummy_data_folder/example0.png | Bin .../example0_downsampled_12_12.jpg | Bin tests/data/mnist_clean_train_dataset_sample | Bin tests/data/rhvae/configs/model_config00.json | 0 .../rhvae/configs/trained_model_folder/model.pt | 0 .../configs/trained_model_folder/model_config.json | 0 tests/data/unnormalized_mnist_data_array | Bin tests/data/unnormalized_mnist_data_list_of_array | Bin tests/pytest.ini | 0 tests/test_AE.py | 0 tests/test_Adversarial_AE.py | 0 tests/test_BetaTCVAE.py | 0 tests/test_BetaVAE.py | 0 tests/test_CIWAE.py | 0 tests/test_DisentangledBetaVAE.py | 0 tests/test_FactorVAE.py | 0 tests/test_HVAE.py | 0 tests/test_IAF.py | 0 tests/test_IWAE.py | 0 tests/test_MADE.py | 0 tests/test_MAF.py | 0 tests/test_MIWAE.py | 0 tests/test_MSSSIMVAE.py | 0 tests/test_PIWAE.py | 0 tests/test_PixelCNN.py | 0 tests/test_PoincareVAE.py | 0 tests/test_RHVAE.py | 0 tests/test_SVAE.py | 0 tests/test_VAE.py | 0 tests/test_VAEGAN.py | 0 tests/test_VAE_IAF.py | 0 tests/test_VAE_LinFlow.py | 0 tests/test_VAMP.py | 0 tests/test_VQVAE.py | 0 tests/test_WAE_MMD.py | 0 tests/test_adversarial_trainer.py | 0 tests/test_auto_model.py | 0 tests/test_baseAE.py | 0 tests/test_baseSampler.py | 0 tests/test_base_trainer.py | 0 tests/test_config.py | 0 .../test_coupled_optimizers_adversarial_trainer.py | 0 tests/test_coupled_optimizers_trainer.py | 0 tests/test_datasets.py | 0 tests/test_gaussian_mixture_sampler.py | 0 tests/test_hypersphere_uniform_sampler.py | 0 tests/test_iaf_sampler.py | 0 tests/test_info_vae_mmd.py | 0 tests/test_maf_sampler.py | 0 tests/test_nn_benchmark.py | 0 tests/test_normal_sampler.py | 0 tests/test_pipeline_standalone.py | 0 tests/test_pixelcnn_sampler.py | 0 tests/test_planar_flow.py | 0 tests/test_preprocessing.py | 0 tests/test_pvae_sampler.py | 0 tests/test_radial_flow.py | 0 tests/test_rae_gp.py | 0 tests/test_rae_l2.py | 0 tests/test_rhvae_sampler.py | 0 tests/test_training_callbacks.py | 0 tests/test_two_stage_sampler.py | 0 tests/test_vamp_sampler.py | 0 tests/your_file.jpeg | 0 605 files changed, 0 insertions(+), 0 deletions(-) mode change 100644 => 100755 .coveragerc mode change 100644 => 100755 .github/ISSUE_TEMPLATE/bug_report.md mode change 100644 => 100755 .github/ISSUE_TEMPLATE/feature_request.md mode change 100644 => 100755 .github/workflows/coverage.yml mode change 100644 => 100755 .github/workflows/tests_bench.yml mode change 100644 => 100755 .gitignore mode change 100644 => 100755 .readthedocs.yml mode change 100644 => 100755 CITATION.cff mode change 100644 => 100755 CONTRIBUTING.md mode change 100644 => 100755 LICENSE mode change 100644 => 100755 MANIFEST.in mode change 100644 => 100755 README.md mode change 100644 => 100755 codecov.yml mode change 100644 => 100755 docs/Makefile mode change 100644 => 100755 docs/README.md mode change 100644 => 100755 docs/make.bat mode change 100644 => 100755 docs/requirements.txt mode change 100644 => 100755 docs/source/_static/basic.css mode change 100644 => 100755 docs/source/_templates/_variables.scss mode change 100644 => 100755 docs/source/_templates/layout.html mode change 100644 => 100755 docs/source/conf.py mode change 100644 => 100755 docs/source/imgs/CNNs.jpeg mode change 100644 => 100755 docs/source/imgs/Case_study_1.jpg mode change 100644 => 100755 docs/source/imgs/DA_diagram(1).pdf mode change 100644 => 100755 docs/source/imgs/DA_diagram.png mode change 100644 => 100755 docs/source/imgs/baseline_results.png mode change 100644 => 100755 docs/source/imgs/logo_pyraug_2.jpeg mode change 100644 => 100755 docs/source/imgs/nine_digits-rot.pdf mode change 100644 => 100755 docs/source/imgs/nine_digits-rot.png mode change 100644 => 100755 docs/source/imgs/nine_digits.pdf mode change 100644 => 100755 docs/source/imgs/nine_digits.png mode change 100644 => 100755 docs/source/imgs/optimized_results.png mode change 100644 => 100755 docs/source/imgs/pyraug_diagram.jpg mode change 100644 => 100755 docs/source/imgs/pyraug_diagram_simplified.jpg mode change 100644 => 100755 docs/source/index.rst mode change 100644 => 100755 docs/source/models/autoencoders/aae.rst mode change 100644 => 100755 docs/source/models/autoencoders/ae.rst mode change 100644 => 100755 docs/source/models/autoencoders/auto_model.rst mode change 100644 => 100755 docs/source/models/autoencoders/baseAE.rst mode change 100644 => 100755 docs/source/models/autoencoders/betatcvae.rst mode change 100644 => 100755 docs/source/models/autoencoders/betavae.rst mode change 100644 => 100755 docs/source/models/autoencoders/ciwae.rst mode change 100644 => 100755 docs/source/models/autoencoders/disentangled_betavae.rst mode change 100644 => 100755 docs/source/models/autoencoders/factorvae.rst mode change 100644 => 100755 docs/source/models/autoencoders/hvae.rst mode change 100644 => 100755 docs/source/models/autoencoders/infovae.rst mode change 100644 => 100755 docs/source/models/autoencoders/iwae.rst mode change 100644 => 100755 docs/source/models/autoencoders/miwae.rst mode change 100644 => 100755 docs/source/models/autoencoders/models.rst mode change 100644 => 100755 docs/source/models/autoencoders/msssimvae.rst mode change 100644 => 100755 docs/source/models/autoencoders/piwae.rst mode change 100644 => 100755 docs/source/models/autoencoders/pvae.rst mode change 100644 => 100755 docs/source/models/autoencoders/rae_gp.rst mode change 100644 => 100755 docs/source/models/autoencoders/rae_l2.rst mode change 100644 => 100755 docs/source/models/autoencoders/rhvae.rst mode change 100644 => 100755 docs/source/models/autoencoders/svae.rst mode change 100644 => 100755 docs/source/models/autoencoders/vae.rst mode change 100644 => 100755 docs/source/models/autoencoders/vae_iaf.rst mode change 100644 => 100755 docs/source/models/autoencoders/vae_lin_nf.rst mode change 100644 => 100755 docs/source/models/autoencoders/vaegan.rst mode change 100644 => 100755 docs/source/models/autoencoders/vamp.rst mode change 100644 => 100755 docs/source/models/autoencoders/vqvae.rst mode change 100644 => 100755 docs/source/models/autoencoders/wae.rst mode change 100644 => 100755 docs/source/models/nn/celeba/convnets.rst mode change 100644 => 100755 docs/source/models/nn/celeba/pythae_benchmarks_nn_celeba.rst mode change 100644 => 100755 docs/source/models/nn/celeba/resnets.rst mode change 100644 => 100755 docs/source/models/nn/cifar/convnets.rst mode change 100644 => 100755 docs/source/models/nn/cifar/pythae_benchmarks_nn_cifar.rst mode change 100644 => 100755 docs/source/models/nn/cifar/resnets.rst mode change 100644 => 100755 docs/source/models/nn/mnist/convnets.rst mode change 100644 => 100755 docs/source/models/nn/mnist/pythae_benchmarks_nn_mnist.rst mode change 100644 => 100755 docs/source/models/nn/mnist/resnets.rst mode change 100644 => 100755 docs/source/models/nn/nn.rst mode change 100644 => 100755 docs/source/models/nn/pythae_base_nn.rst mode change 100644 => 100755 docs/source/models/nn/pythae_benchmarks_nn.rst mode change 100644 => 100755 docs/source/models/normalizing_flows/basenf.rst mode change 100644 => 100755 docs/source/models/normalizing_flows/iaf.rst mode change 100644 => 100755 docs/source/models/normalizing_flows/made.rst mode change 100644 => 100755 docs/source/models/normalizing_flows/maf.rst mode change 100644 => 100755 docs/source/models/normalizing_flows/normalizing_flows.rst mode change 100644 => 100755 docs/source/models/normalizing_flows/pixelcnn.rst mode change 100644 => 100755 docs/source/models/normalizing_flows/planar_flow.rst mode change 100644 => 100755 docs/source/models/normalizing_flows/radial_flow.rst mode change 100644 => 100755 docs/source/models/pythae.models.rst mode change 100644 => 100755 docs/source/pipelines/generation.rst mode change 100644 => 100755 docs/source/pipelines/pythae.pipelines.rst mode change 100644 => 100755 docs/source/pipelines/training.rst mode change 100644 => 100755 docs/source/references.bib mode change 100644 => 100755 docs/source/samplers/basesampler.rst mode change 100644 => 100755 docs/source/samplers/gmm_sampler.rst mode change 100644 => 100755 docs/source/samplers/iaf_sampler.rst mode change 100644 => 100755 docs/source/samplers/maf_sampler.rst mode change 100644 => 100755 docs/source/samplers/normal_sampler.rst mode change 100644 => 100755 docs/source/samplers/pixelcnn_sampler.rst mode change 100644 => 100755 docs/source/samplers/poincare_disk_sampler.rst mode change 100644 => 100755 docs/source/samplers/pythae.samplers.rst mode change 100644 => 100755 docs/source/samplers/rhvae_sampler.rst mode change 100644 => 100755 docs/source/samplers/twostage_sampler.rst mode change 100644 => 100755 docs/source/samplers/unit_sphere_unif_sampler.rst mode change 100644 => 100755 docs/source/samplers/vamp_sampler.rst mode change 100644 => 100755 docs/source/trainers/pythae.trainer.rst mode change 100644 => 100755 docs/source/trainers/pythae.trainers.adversarial_trainer.rst mode change 100644 => 100755 docs/source/trainers/pythae.trainers.base_trainer.rst mode change 100644 => 100755 docs/source/trainers/pythae.trainers.coupled_optimizer_adversarial_trainer.rst mode change 100644 => 100755 docs/source/trainers/pythae.trainers.coupled_trainer.rst mode change 100644 => 100755 docs/source/trainers/pythae.training_callbacks.rst mode change 100644 => 100755 examples/notebooks/README.md mode change 100644 => 100755 examples/notebooks/comet_experiment_monitoring.ipynb mode change 100644 => 100755 examples/notebooks/custom_dataset.ipynb mode change 100644 => 100755 examples/notebooks/hf_hub_models_sharing.ipynb mode change 100644 => 100755 examples/notebooks/making_your_own_autoencoder.ipynb mode change 100644 => 100755 examples/notebooks/mlflow_experiment_monitoring.ipynb mode change 100644 => 100755 examples/notebooks/models_training/adversarial_ae_training.ipynb mode change 100644 => 100755 examples/notebooks/models_training/ae_training.ipynb mode change 100644 => 100755 examples/notebooks/models_training/beta_tc_vae_training.ipynb mode change 100644 => 100755 examples/notebooks/models_training/beta_vae_training.ipynb mode change 100644 => 100755 examples/notebooks/models_training/ciwae_training.ipynb mode change 100644 => 100755 examples/notebooks/models_training/disentangled_beta_vae_training.ipynb mode change 100644 => 100755 examples/notebooks/models_training/factor_vae_training.ipynb mode change 100644 => 100755 examples/notebooks/models_training/hvae_training.ipynb mode change 100644 => 100755 examples/notebooks/models_training/info_vae_training.ipynb mode change 100644 => 100755 examples/notebooks/models_training/iwae_training.ipynb mode change 100644 => 100755 examples/notebooks/models_training/miwae_training.ipynb mode change 100644 => 100755 examples/notebooks/models_training/ms_ssim_vae_training.ipynb mode change 100644 => 100755 examples/notebooks/models_training/normalizing_flows_training.ipynb mode change 100644 => 100755 examples/notebooks/models_training/piwae_training.ipynb mode change 100644 => 100755 examples/notebooks/models_training/pvae_training.ipynb mode change 100644 => 100755 examples/notebooks/models_training/rae_gp_training.ipynb mode change 100644 => 100755 examples/notebooks/models_training/rae_l2_training.ipynb mode change 100644 => 100755 examples/notebooks/models_training/rhvae_training.ipynb mode change 100644 => 100755 examples/notebooks/models_training/svae_training.ipynb mode change 100644 => 100755 examples/notebooks/models_training/vae_iaf_training.ipynb mode change 100644 => 100755 examples/notebooks/models_training/vae_lin_nf_training.ipynb mode change 100644 => 100755 examples/notebooks/models_training/vae_lstm.ipynb mode change 100644 => 100755 examples/notebooks/models_training/vae_training.ipynb mode change 100644 => 100755 examples/notebooks/models_training/vaegan_training.ipynb mode change 100644 => 100755 examples/notebooks/models_training/vamp_training.ipynb mode change 100644 => 100755 examples/notebooks/models_training/vqvae_training.ipynb mode change 100644 => 100755 examples/notebooks/models_training/wae_training.ipynb mode change 100644 => 100755 examples/notebooks/overview_notebook.ipynb mode change 100644 => 100755 examples/notebooks/requirements.txt mode change 100644 => 100755 examples/notebooks/wandb_experiment_monitoring.ipynb mode change 100644 => 100755 examples/scripts/README.md mode change 100644 => 100755 examples/scripts/configs/binary_mnist/base_training_config.json mode change 100644 => 100755 examples/scripts/configs/binary_mnist/hvae_config.json mode change 100644 => 100755 examples/scripts/configs/binary_mnist/iwae_config.json mode change 100644 => 100755 examples/scripts/configs/binary_mnist/rhvae_config.json mode change 100644 => 100755 examples/scripts/configs/binary_mnist/svae_config.json mode change 100644 => 100755 examples/scripts/configs/binary_mnist/vae_config.json mode change 100644 => 100755 examples/scripts/configs/binary_mnist/vae_iaf_config.json mode change 100644 => 100755 examples/scripts/configs/binary_mnist/vae_lin_nf_config.json mode change 100644 => 100755 examples/scripts/configs/binary_mnist/vamp_config.json mode change 100644 => 100755 examples/scripts/configs/celeba/aae_config.json mode change 100644 => 100755 examples/scripts/configs/celeba/ae_config.json mode change 100644 => 100755 examples/scripts/configs/celeba/base_training_config.json mode change 100644 => 100755 examples/scripts/configs/celeba/beta_tc_vae_config.json mode change 100644 => 100755 examples/scripts/configs/celeba/beta_vae_config.json mode change 100644 => 100755 examples/scripts/configs/celeba/disentangled_beta_vae_config.json mode change 100644 => 100755 examples/scripts/configs/celeba/factor_vae_config.json mode change 100644 => 100755 examples/scripts/configs/celeba/hvae_config.json mode change 100644 => 100755 examples/scripts/configs/celeba/info_vae_config.json mode change 100644 => 100755 examples/scripts/configs/celeba/msssim_vae_config.json mode change 100644 => 100755 examples/scripts/configs/celeba/rae_gp_config.json mode change 100644 => 100755 examples/scripts/configs/celeba/rae_l2_config.json mode change 100644 => 100755 examples/scripts/configs/celeba/rhvae_config.json mode change 100644 => 100755 examples/scripts/configs/celeba/svae_config.json mode change 100644 => 100755 examples/scripts/configs/celeba/vae_config.json mode change 100644 => 100755 examples/scripts/configs/celeba/vae_iaf_config.json mode change 100644 => 100755 examples/scripts/configs/celeba/vae_lin_nf_config.json mode change 100644 => 100755 examples/scripts/configs/celeba/vaegan_config.json mode change 100644 => 100755 examples/scripts/configs/celeba/vamp_config.json mode change 100644 => 100755 examples/scripts/configs/celeba/vqvae_config.json mode change 100644 => 100755 examples/scripts/configs/celeba/wae_config.json mode change 100644 => 100755 examples/scripts/configs/cifar10/ae_config.json mode change 100644 => 100755 examples/scripts/configs/cifar10/base_training_config.json mode change 100644 => 100755 examples/scripts/configs/cifar10/beta_vae_config.json mode change 100644 => 100755 examples/scripts/configs/cifar10/hvae_config.json mode change 100644 => 100755 examples/scripts/configs/cifar10/info_vae_config.json mode change 100644 => 100755 examples/scripts/configs/cifar10/rae_gp_config.json mode change 100644 => 100755 examples/scripts/configs/cifar10/rae_l2_config.json mode change 100644 => 100755 examples/scripts/configs/cifar10/rhvae_config.json mode change 100644 => 100755 examples/scripts/configs/cifar10/vae_config.json mode change 100644 => 100755 examples/scripts/configs/cifar10/vamp_config.json mode change 100644 => 100755 examples/scripts/configs/cifar10/vqvae_config.json mode change 100644 => 100755 examples/scripts/configs/cifar10/wae_config.json mode change 100644 => 100755 examples/scripts/configs/dsprites/beta_tc_vae/base_training_config.json mode change 100644 => 100755 examples/scripts/configs/dsprites/beta_tc_vae/beta_tc_vae_config.json mode change 100644 => 100755 examples/scripts/configs/dsprites/factorvae/base_training_config.json mode change 100644 => 100755 examples/scripts/configs/dsprites/factorvae/factorvae_config.json mode change 100644 => 100755 examples/scripts/configs/mnist/aae_config.json mode change 100644 => 100755 examples/scripts/configs/mnist/ae_config.json mode change 100644 => 100755 examples/scripts/configs/mnist/base_training_config.json mode change 100644 => 100755 examples/scripts/configs/mnist/beta_tc_vae_config.json mode change 100644 => 100755 examples/scripts/configs/mnist/beta_vae_config.json mode change 100644 => 100755 examples/scripts/configs/mnist/disentangled_beta_vae_config.json mode change 100644 => 100755 examples/scripts/configs/mnist/factor_vae_config.json mode change 100644 => 100755 examples/scripts/configs/mnist/hvae_config.json mode change 100644 => 100755 examples/scripts/configs/mnist/info_vae_config.json mode change 100644 => 100755 examples/scripts/configs/mnist/iwae_config.json mode change 100644 => 100755 examples/scripts/configs/mnist/msssim_vae_config.json mode change 100644 => 100755 examples/scripts/configs/mnist/rae_gp_config.json mode change 100644 => 100755 examples/scripts/configs/mnist/rae_l2_config.json mode change 100644 => 100755 examples/scripts/configs/mnist/rhvae_config.json mode change 100644 => 100755 examples/scripts/configs/mnist/svae_config.json mode change 100644 => 100755 examples/scripts/configs/mnist/vae_config.json mode change 100644 => 100755 examples/scripts/configs/mnist/vae_iaf_config.json mode change 100644 => 100755 examples/scripts/configs/mnist/vae_lin_nf_config.json mode change 100644 => 100755 examples/scripts/configs/mnist/vaegan_config.json mode change 100644 => 100755 examples/scripts/configs/mnist/vamp_config.json mode change 100644 => 100755 examples/scripts/configs/mnist/vqvae_config.json mode change 100644 => 100755 examples/scripts/configs/mnist/wae_config.json mode change 100644 => 100755 examples/scripts/custom_nn.py mode change 100644 => 100755 examples/scripts/distributed_training_ffhq.py mode change 100644 => 100755 examples/scripts/distributed_training_imagenet.py mode change 100644 => 100755 examples/scripts/distributed_training_mnist.py mode change 100644 => 100755 examples/scripts/reproducibility/README.md mode change 100644 => 100755 examples/scripts/reproducibility/aae.py mode change 100644 => 100755 examples/scripts/reproducibility/betatcvae.py mode change 100644 => 100755 examples/scripts/reproducibility/ciwae.py mode change 100644 => 100755 examples/scripts/reproducibility/hvae.py mode change 100644 => 100755 examples/scripts/reproducibility/iwae.py mode change 100644 => 100755 examples/scripts/reproducibility/miwae.py mode change 100644 => 100755 examples/scripts/reproducibility/piwae.py mode change 100644 => 100755 examples/scripts/reproducibility/pvae.py mode change 100644 => 100755 examples/scripts/reproducibility/rae_gp.py mode change 100644 => 100755 examples/scripts/reproducibility/rae_l2.py mode change 100644 => 100755 examples/scripts/reproducibility/svae.py mode change 100644 => 100755 examples/scripts/reproducibility/vae.py mode change 100644 => 100755 examples/scripts/reproducibility/vamp.py mode change 100644 => 100755 examples/scripts/reproducibility/wae.py mode change 100644 => 100755 examples/scripts/training.py mode change 100644 => 100755 examples/showcases/aae_normal_sampling_celeba.png mode change 100644 => 100755 examples/showcases/aae_normal_sampling_mnist.png mode change 100644 => 100755 examples/showcases/aae_reconstruction_celeba.png mode change 100644 => 100755 examples/showcases/aae_reconstruction_mnist.png mode change 100644 => 100755 examples/showcases/ae_gmm_sampling_celeba.png mode change 100644 => 100755 examples/showcases/ae_gmm_sampling_mnist.png mode change 100644 => 100755 examples/showcases/ae_normal_sampling_celeba.png mode change 100644 => 100755 examples/showcases/ae_normal_sampling_mnist.png mode change 100644 => 100755 examples/showcases/ae_reconstruction_celeba.png mode change 100644 => 100755 examples/showcases/ae_reconstruction_mnist.png mode change 100644 => 100755 examples/showcases/beta_tc_vae_normal_sampling_celeba.png mode change 100644 => 100755 examples/showcases/beta_tc_vae_normal_sampling_mnist.png mode change 100644 => 100755 examples/showcases/beta_tc_vae_reconstruction_celeba.png mode change 100644 => 100755 examples/showcases/beta_tc_vae_reconstruction_mnist.png mode change 100644 => 100755 examples/showcases/beta_vae_normal_sampling_celeba.png mode change 100644 => 100755 examples/showcases/beta_vae_normal_sampling_mnist.png mode change 100644 => 100755 examples/showcases/beta_vae_reconstruction_celeba.png mode change 100644 => 100755 examples/showcases/beta_vae_reconstruction_mnist.png mode change 100644 => 100755 examples/showcases/disentangled_beta_vae_normal_sampling_celeba.png mode change 100644 => 100755 examples/showcases/disentangled_beta_vae_normal_sampling_mnist.png mode change 100644 => 100755 examples/showcases/disentangled_beta_vae_reconstruction_celeba.png mode change 100644 => 100755 examples/showcases/disentangled_beta_vae_reconstruction_mnist.png mode change 100644 => 100755 examples/showcases/eval_reconstruction_celeba.png mode change 100644 => 100755 examples/showcases/eval_reconstruction_mnist.png mode change 100644 => 100755 examples/showcases/factor_vae_normal_sampling_celeba.png mode change 100644 => 100755 examples/showcases/factor_vae_normal_sampling_mnist.png mode change 100644 => 100755 examples/showcases/factor_vae_reconstruction_celeba.png mode change 100644 => 100755 examples/showcases/factor_vae_reconstruction_mnist.png mode change 100644 => 100755 examples/showcases/hvae_gmm_sampling_mnist.png mode change 100644 => 100755 examples/showcases/hvae_normal_sampling_celeba.png mode change 100644 => 100755 examples/showcases/hvae_normal_sampling_mnist.png mode change 100644 => 100755 examples/showcases/hvae_reconstruction_celeba.png mode change 100644 => 100755 examples/showcases/hvae_reconstruction_mnist.png mode change 100644 => 100755 examples/showcases/hvvae_gmm_sampling_mnist.png mode change 100644 => 100755 examples/showcases/hvvae_normal_sampling_mnist.png mode change 100644 => 100755 examples/showcases/infovae_normal_sampling_celeba.png mode change 100644 => 100755 examples/showcases/infovae_normal_sampling_mnist.png mode change 100644 => 100755 examples/showcases/infovae_reconstruction_celeba.png mode change 100644 => 100755 examples/showcases/infovae_reconstruction_mnist.png mode change 100644 => 100755 examples/showcases/iwae_gmm_sampling_mnist.png mode change 100644 => 100755 examples/showcases/iwae_normal_sampling_celeba.png mode change 100644 => 100755 examples/showcases/iwae_normal_sampling_mnist.png mode change 100644 => 100755 examples/showcases/iwae_reconstruction_celeba.png mode change 100644 => 100755 examples/showcases/iwae_reconstruction_mnist.png mode change 100644 => 100755 examples/showcases/msssim_vae_normal_sampling_celeba.png mode change 100644 => 100755 examples/showcases/msssim_vae_normal_sampling_mnist.png mode change 100644 => 100755 examples/showcases/msssim_vae_reconstruction_celeba.png mode change 100644 => 100755 examples/showcases/msssim_vae_reconstruction_mnist.png mode change 100644 => 100755 examples/showcases/rae_gp_gmm_sampling_celeba.png mode change 100644 => 100755 examples/showcases/rae_gp_gmm_sampling_mnist.png mode change 100644 => 100755 examples/showcases/rae_gp_normal_sampling_celeba.png mode change 100644 => 100755 examples/showcases/rae_gp_normal_sampling_mnist.png mode change 100644 => 100755 examples/showcases/rae_gp_reconstruction_celeba.png mode change 100644 => 100755 examples/showcases/rae_gp_reconstruction_mnist.png mode change 100644 => 100755 examples/showcases/rae_l2_gmm_sampling_celeba.png mode change 100644 => 100755 examples/showcases/rae_l2_gmm_sampling_mnist.png mode change 100644 => 100755 examples/showcases/rae_l2_normal_sampling_celeba.png mode change 100644 => 100755 examples/showcases/rae_l2_normal_sampling_mnist.png mode change 100644 => 100755 examples/showcases/rae_l2_reconstruction_celeba.png mode change 100644 => 100755 examples/showcases/rae_l2_reconstruction_mnist.png mode change 100644 => 100755 examples/showcases/rhvae_reconstruction_celeba.png mode change 100644 => 100755 examples/showcases/rhvae_reconstruction_mnist.png mode change 100644 => 100755 examples/showcases/rhvae_rhvae_sampling_celeba.png mode change 100644 => 100755 examples/showcases/rhvae_rhvae_sampling_mnist.png mode change 100644 => 100755 examples/showcases/svae_hypersphere_uniform_sampling_celeba.png mode change 100644 => 100755 examples/showcases/svae_hypersphere_uniform_sampling_mnist.png mode change 100644 => 100755 examples/showcases/svae_reconstruction_celeba.png mode change 100644 => 100755 examples/showcases/svae_reconstruction_mnist.png mode change 100644 => 100755 examples/showcases/vae_gmm_sampling_celeba.png mode change 100644 => 100755 examples/showcases/vae_gmm_sampling_mnist.png mode change 100644 => 100755 examples/showcases/vae_iaf_normal_sampling_celeba.png mode change 100644 => 100755 examples/showcases/vae_iaf_normal_sampling_mnist.png mode change 100644 => 100755 examples/showcases/vae_iaf_reconstruction_celeba.png mode change 100644 => 100755 examples/showcases/vae_iaf_reconstruction_mnist.png mode change 100644 => 100755 examples/showcases/vae_lin_nf_normal_sampling_celeba.png mode change 100644 => 100755 examples/showcases/vae_lin_nf_normal_sampling_mnist.png mode change 100644 => 100755 examples/showcases/vae_lin_nf_reconstruction_celeba.png mode change 100644 => 100755 examples/showcases/vae_lin_nf_reconstruction_mnist.png mode change 100644 => 100755 examples/showcases/vae_maf_sampling_celeba.png mode change 100644 => 100755 examples/showcases/vae_maf_sampling_mnist.png mode change 100644 => 100755 examples/showcases/vae_normal_sampling_celeba.png mode change 100644 => 100755 examples/showcases/vae_normal_sampling_mnist.png mode change 100644 => 100755 examples/showcases/vae_reconstruction_celeba.png mode change 100644 => 100755 examples/showcases/vae_reconstruction_mnist.png mode change 100644 => 100755 examples/showcases/vae_second_stage_sampling_celeba.png mode change 100644 => 100755 examples/showcases/vae_second_stage_sampling_mnist.png mode change 100644 => 100755 examples/showcases/vaegan_normal_sampling_celeba.png mode change 100644 => 100755 examples/showcases/vaegan_normal_sampling_mnist.png mode change 100644 => 100755 examples/showcases/vaegan_reconstruction_celeba.png mode change 100644 => 100755 examples/showcases/vaegan_reconstruction_mnist.png mode change 100644 => 100755 examples/showcases/vamp_reconstruction_celeba.png mode change 100644 => 100755 examples/showcases/vamp_reconstruction_mnist.png mode change 100644 => 100755 examples/showcases/vamp_vamp_sampling_celeba.png mode change 100644 => 100755 examples/showcases/vamp_vamp_sampling_mnist.png mode change 100644 => 100755 examples/showcases/vqvae_maf_sampling_celeba.png mode change 100644 => 100755 examples/showcases/vqvae_maf_sampling_mnist.png mode change 100644 => 100755 examples/showcases/vqvae_pixelcnn_sampling_mnist.png mode change 100644 => 100755 examples/showcases/vqvae_reconstruction_celeba.png mode change 100644 => 100755 examples/showcases/vqvae_reconstruction_mnist.png mode change 100644 => 100755 examples/showcases/wae_gmm_sampling_celeba.png mode change 100644 => 100755 examples/showcases/wae_gmm_sampling_mnist.png mode change 100644 => 100755 examples/showcases/wae_normal_sampling_celeba.png mode change 100644 => 100755 examples/showcases/wae_normal_sampling_mnist.png mode change 100644 => 100755 examples/showcases/wae_reconstruction_celeba.png mode change 100644 => 100755 examples/showcases/wae_reconstruction_mnist.png mode change 100644 => 100755 pyproject.toml mode change 100644 => 100755 requirements.txt mode change 100644 => 100755 setup.cfg mode change 100644 => 100755 setup.py mode change 100644 => 100755 src/pythae/__init__.py mode change 100644 => 100755 src/pythae/config.py mode change 100644 => 100755 src/pythae/customexception.py mode change 100644 => 100755 src/pythae/data/__init__.py mode change 100644 => 100755 src/pythae/data/datasets.py mode change 100644 => 100755 src/pythae/data/preprocessors.py mode change 100644 => 100755 src/pythae/models/__init__.py mode change 100644 => 100755 src/pythae/models/adversarial_ae/__init__.py mode change 100644 => 100755 src/pythae/models/adversarial_ae/adversarial_ae_config.py mode change 100644 => 100755 src/pythae/models/adversarial_ae/adversarial_ae_model.py mode change 100644 => 100755 src/pythae/models/ae/__init__.py mode change 100644 => 100755 src/pythae/models/ae/ae_config.py mode change 100644 => 100755 src/pythae/models/ae/ae_model.py mode change 100644 => 100755 src/pythae/models/auto_model/__init__.py mode change 100644 => 100755 src/pythae/models/auto_model/auto_config.py mode change 100644 => 100755 src/pythae/models/auto_model/auto_model.py mode change 100644 => 100755 src/pythae/models/base/__init__.py mode change 100644 => 100755 src/pythae/models/base/base_config.py mode change 100644 => 100755 src/pythae/models/base/base_model.py mode change 100644 => 100755 src/pythae/models/base/base_utils.py mode change 100644 => 100755 src/pythae/models/beta_tc_vae/__init__.py mode change 100644 => 100755 src/pythae/models/beta_tc_vae/beta_tc_vae_config.py mode change 100644 => 100755 src/pythae/models/beta_tc_vae/beta_tc_vae_model.py mode change 100644 => 100755 src/pythae/models/beta_vae/__init__.py mode change 100644 => 100755 src/pythae/models/beta_vae/beta_vae_config.py mode change 100644 => 100755 src/pythae/models/beta_vae/beta_vae_model.py mode change 100644 => 100755 src/pythae/models/ciwae/__init__.py mode change 100644 => 100755 src/pythae/models/ciwae/ciwae_config.py mode change 100644 => 100755 src/pythae/models/ciwae/ciwae_model.py mode change 100644 => 100755 src/pythae/models/disentangled_beta_vae/__init__.py mode change 100644 => 100755 src/pythae/models/disentangled_beta_vae/disentangled_beta_vae_config.py mode change 100644 => 100755 src/pythae/models/disentangled_beta_vae/disentangled_beta_vae_model.py mode change 100644 => 100755 src/pythae/models/factor_vae/__init__.py mode change 100644 => 100755 src/pythae/models/factor_vae/factor_vae_config.py mode change 100644 => 100755 src/pythae/models/factor_vae/factor_vae_model.py mode change 100644 => 100755 src/pythae/models/factor_vae/factor_vae_utils.py mode change 100644 => 100755 src/pythae/models/hvae/__init__.py mode change 100644 => 100755 src/pythae/models/hvae/hvae_config.py mode change 100644 => 100755 src/pythae/models/hvae/hvae_model.py mode change 100644 => 100755 src/pythae/models/info_vae/__init__.py mode change 100644 => 100755 src/pythae/models/info_vae/info_vae_config.py mode change 100644 => 100755 src/pythae/models/info_vae/info_vae_model.py mode change 100644 => 100755 src/pythae/models/iwae/__init__.py mode change 100644 => 100755 src/pythae/models/iwae/iwae_config.py mode change 100644 => 100755 src/pythae/models/iwae/iwae_model.py mode change 100644 => 100755 src/pythae/models/miwae/__init__.py mode change 100644 => 100755 src/pythae/models/miwae/miwae_config.py mode change 100644 => 100755 src/pythae/models/miwae/miwae_model.py mode change 100644 => 100755 src/pythae/models/msssim_vae/__init__.py mode change 100644 => 100755 src/pythae/models/msssim_vae/msssim_vae_config.py mode change 100644 => 100755 src/pythae/models/msssim_vae/msssim_vae_model.py mode change 100644 => 100755 src/pythae/models/msssim_vae/msssim_vae_utils.py mode change 100644 => 100755 src/pythae/models/nn/__init__.py mode change 100644 => 100755 src/pythae/models/nn/base_architectures.py mode change 100644 => 100755 src/pythae/models/nn/benchmarks/__init__.py mode change 100644 => 100755 src/pythae/models/nn/benchmarks/celeba/__init__.py mode change 100644 => 100755 src/pythae/models/nn/benchmarks/celeba/convnets.py mode change 100644 => 100755 src/pythae/models/nn/benchmarks/celeba/resnets.py mode change 100644 => 100755 src/pythae/models/nn/benchmarks/cifar/__init__.py mode change 100644 => 100755 src/pythae/models/nn/benchmarks/cifar/convnets.py mode change 100644 => 100755 src/pythae/models/nn/benchmarks/cifar/resnets.py mode change 100644 => 100755 src/pythae/models/nn/benchmarks/mnist/__init__.py mode change 100644 => 100755 src/pythae/models/nn/benchmarks/mnist/convnets.py mode change 100644 => 100755 src/pythae/models/nn/benchmarks/mnist/resnets.py mode change 100644 => 100755 src/pythae/models/nn/benchmarks/utils.py mode change 100644 => 100755 src/pythae/models/nn/default_architectures.py mode change 100644 => 100755 src/pythae/models/normalizing_flows/__init__.py mode change 100644 => 100755 src/pythae/models/normalizing_flows/base/__init__.py mode change 100644 => 100755 src/pythae/models/normalizing_flows/base/base_nf_config.py mode change 100644 => 100755 src/pythae/models/normalizing_flows/base/base_nf_model.py mode change 100644 => 100755 src/pythae/models/normalizing_flows/iaf/__init__.py mode change 100644 => 100755 src/pythae/models/normalizing_flows/iaf/iaf_config.py mode change 100644 => 100755 src/pythae/models/normalizing_flows/iaf/iaf_model.py mode change 100644 => 100755 src/pythae/models/normalizing_flows/layers.py mode change 100644 => 100755 src/pythae/models/normalizing_flows/made/__init__.py mode change 100644 => 100755 src/pythae/models/normalizing_flows/made/made_config.py mode change 100644 => 100755 src/pythae/models/normalizing_flows/made/made_model.py mode change 100644 => 100755 src/pythae/models/normalizing_flows/maf/__init__.py mode change 100644 => 100755 src/pythae/models/normalizing_flows/maf/maf_config.py mode change 100644 => 100755 src/pythae/models/normalizing_flows/maf/maf_model.py mode change 100644 => 100755 src/pythae/models/normalizing_flows/pixelcnn/__init__.py mode change 100644 => 100755 src/pythae/models/normalizing_flows/pixelcnn/pixelcnn_config.py mode change 100644 => 100755 src/pythae/models/normalizing_flows/pixelcnn/pixelcnn_model.py mode change 100644 => 100755 src/pythae/models/normalizing_flows/pixelcnn/utils.py mode change 100644 => 100755 src/pythae/models/normalizing_flows/planar_flow/__init__.py mode change 100644 => 100755 src/pythae/models/normalizing_flows/planar_flow/planar_flow_config.py mode change 100644 => 100755 src/pythae/models/normalizing_flows/planar_flow/planar_flow_model.py mode change 100644 => 100755 src/pythae/models/normalizing_flows/radial_flow/__init__.py mode change 100644 => 100755 src/pythae/models/normalizing_flows/radial_flow/radial_flow_config.py mode change 100644 => 100755 src/pythae/models/normalizing_flows/radial_flow/radial_flow_model.py mode change 100644 => 100755 src/pythae/models/piwae/__init__.py mode change 100644 => 100755 src/pythae/models/piwae/piwae_config.py mode change 100644 => 100755 src/pythae/models/piwae/piwae_model.py mode change 100644 => 100755 src/pythae/models/pvae/__init__.py mode change 100644 => 100755 src/pythae/models/pvae/pvae_config.py mode change 100644 => 100755 src/pythae/models/pvae/pvae_model.py mode change 100644 => 100755 src/pythae/models/pvae/pvae_utils.py mode change 100644 => 100755 src/pythae/models/rae_gp/__init__.py mode change 100644 => 100755 src/pythae/models/rae_gp/rae_gp_config.py mode change 100644 => 100755 src/pythae/models/rae_gp/rae_gp_model.py mode change 100644 => 100755 src/pythae/models/rae_l2/__init__.py mode change 100644 => 100755 src/pythae/models/rae_l2/rae_l2_config.py mode change 100644 => 100755 src/pythae/models/rae_l2/rae_l2_model.py mode change 100644 => 100755 src/pythae/models/rhvae/__init__.py mode change 100644 => 100755 src/pythae/models/rhvae/rhvae_config.py mode change 100644 => 100755 src/pythae/models/rhvae/rhvae_model.py mode change 100644 => 100755 src/pythae/models/rhvae/rhvae_utils.py mode change 100644 => 100755 src/pythae/models/svae/__init__.py mode change 100644 => 100755 src/pythae/models/svae/svae_config.py mode change 100644 => 100755 src/pythae/models/svae/svae_model.py mode change 100644 => 100755 src/pythae/models/svae/svae_utils.py mode change 100644 => 100755 src/pythae/models/vae/__init__.py mode change 100644 => 100755 src/pythae/models/vae/vae_config.py mode change 100644 => 100755 src/pythae/models/vae/vae_model.py mode change 100644 => 100755 src/pythae/models/vae_gan/__init__.py mode change 100644 => 100755 src/pythae/models/vae_gan/vae_gan_config.py mode change 100644 => 100755 src/pythae/models/vae_gan/vae_gan_model.py mode change 100644 => 100755 src/pythae/models/vae_iaf/__init__.py mode change 100644 => 100755 src/pythae/models/vae_iaf/vae_iaf_config.py mode change 100644 => 100755 src/pythae/models/vae_iaf/vae_iaf_model.py mode change 100644 => 100755 src/pythae/models/vae_lin_nf/__init__.py mode change 100644 => 100755 src/pythae/models/vae_lin_nf/vae_lin_nf_config.py mode change 100644 => 100755 src/pythae/models/vae_lin_nf/vae_lin_nf_model.py mode change 100644 => 100755 src/pythae/models/vamp/__init__.py mode change 100644 => 100755 src/pythae/models/vamp/vamp_config.py mode change 100644 => 100755 src/pythae/models/vamp/vamp_model.py mode change 100644 => 100755 src/pythae/models/vq_vae/__init__.py mode change 100644 => 100755 src/pythae/models/vq_vae/vq_vae_config.py mode change 100644 => 100755 src/pythae/models/vq_vae/vq_vae_model.py mode change 100644 => 100755 src/pythae/models/vq_vae/vq_vae_utils.py mode change 100644 => 100755 src/pythae/models/wae_mmd/__init__.py mode change 100644 => 100755 src/pythae/models/wae_mmd/wae_mmd_config.py mode change 100644 => 100755 src/pythae/models/wae_mmd/wae_mmd_model.py mode change 100644 => 100755 src/pythae/pipelines/__init__.py mode change 100644 => 100755 src/pythae/pipelines/base_pipeline.py mode change 100644 => 100755 src/pythae/pipelines/generation.py mode change 100644 => 100755 src/pythae/pipelines/pipeline_utils.py mode change 100644 => 100755 src/pythae/pipelines/training.py mode change 100644 => 100755 src/pythae/py.typed mode change 100644 => 100755 src/pythae/samplers/__init__.py mode change 100644 => 100755 src/pythae/samplers/base/__init__.py mode change 100644 => 100755 src/pythae/samplers/base/base_sampler.py mode change 100644 => 100755 src/pythae/samplers/base/base_sampler_config.py mode change 100644 => 100755 src/pythae/samplers/gaussian_mixture/__init__.py mode change 100644 => 100755 src/pythae/samplers/gaussian_mixture/gaussian_mixture_config.py mode change 100644 => 100755 src/pythae/samplers/gaussian_mixture/gaussian_mixture_sampler.py mode change 100644 => 100755 src/pythae/samplers/hypersphere_uniform_sampler/__init__.py mode change 100644 => 100755 src/pythae/samplers/hypersphere_uniform_sampler/hypersphere_uniform_config.py mode change 100644 => 100755 src/pythae/samplers/hypersphere_uniform_sampler/hypersphere_uniform_sampler.py mode change 100644 => 100755 src/pythae/samplers/iaf_sampler/__init__.py mode change 100644 => 100755 src/pythae/samplers/iaf_sampler/iaf_sampler.py mode change 100644 => 100755 src/pythae/samplers/iaf_sampler/iaf_sampler_config.py mode change 100644 => 100755 src/pythae/samplers/maf_sampler/__init__.py mode change 100644 => 100755 src/pythae/samplers/maf_sampler/maf_sampler.py mode change 100644 => 100755 src/pythae/samplers/maf_sampler/maf_sampler_config.py mode change 100644 => 100755 src/pythae/samplers/manifold_sampler/__init__.py mode change 100644 => 100755 src/pythae/samplers/manifold_sampler/rhvae_sampler.py mode change 100644 => 100755 src/pythae/samplers/manifold_sampler/rhvae_sampler_config.py mode change 100644 => 100755 src/pythae/samplers/normal_sampling/__init__.py mode change 100644 => 100755 src/pythae/samplers/normal_sampling/normal_config.py mode change 100644 => 100755 src/pythae/samplers/normal_sampling/normal_sampler.py mode change 100644 => 100755 src/pythae/samplers/pixelcnn_sampler/__init__.py mode change 100644 => 100755 src/pythae/samplers/pixelcnn_sampler/pixelcnn_sampler.py mode change 100644 => 100755 src/pythae/samplers/pixelcnn_sampler/pixelcnn_sampler_config.py mode change 100644 => 100755 src/pythae/samplers/pvae_sampler/__init__.py mode change 100644 => 100755 src/pythae/samplers/pvae_sampler/pvae_sampler.py mode change 100644 => 100755 src/pythae/samplers/pvae_sampler/pvae_sampler_config.py mode change 100644 => 100755 src/pythae/samplers/two_stage_vae_sampler/__init__.py mode change 100644 => 100755 src/pythae/samplers/two_stage_vae_sampler/two_stage_sampler.py mode change 100644 => 100755 src/pythae/samplers/two_stage_vae_sampler/two_stage_sampler_config.py mode change 100644 => 100755 src/pythae/samplers/vamp_sampler/__init__.py mode change 100644 => 100755 src/pythae/samplers/vamp_sampler/vamp_sampler.py mode change 100644 => 100755 src/pythae/samplers/vamp_sampler/vamp_sampler_config.py mode change 100644 => 100755 src/pythae/trainers/__init__.py mode change 100644 => 100755 src/pythae/trainers/adversarial_trainer/__init__.py mode change 100644 => 100755 src/pythae/trainers/adversarial_trainer/adversarial_trainer.py mode change 100644 => 100755 src/pythae/trainers/adversarial_trainer/adversarial_trainer_config.py mode change 100644 => 100755 src/pythae/trainers/base_trainer/__init__.py mode change 100644 => 100755 src/pythae/trainers/base_trainer/base_trainer.py mode change 100644 => 100755 src/pythae/trainers/base_trainer/base_training_config.py mode change 100644 => 100755 src/pythae/trainers/coupled_optimizer_adversarial_trainer/__init__.py mode change 100644 => 100755 src/pythae/trainers/coupled_optimizer_adversarial_trainer/coupled_optimizer_adversarial_trainer.py mode change 100644 => 100755 src/pythae/trainers/coupled_optimizer_adversarial_trainer/coupled_optimizer_adversarial_trainer_config.py mode change 100644 => 100755 src/pythae/trainers/coupled_optimizer_trainer/__init__.py mode change 100644 => 100755 src/pythae/trainers/coupled_optimizer_trainer/coupled_optimizer_trainer.py mode change 100644 => 100755 src/pythae/trainers/coupled_optimizer_trainer/coupled_optimizer_trainer_config.py mode change 100644 => 100755 src/pythae/trainers/trainer_utils.py mode change 100644 => 100755 src/pythae/trainers/training_callbacks.py mode change 100644 => 100755 tests/README.md mode change 100644 => 100755 tests/__init__.py mode change 100644 => 100755 tests/conftest.py mode change 100644 => 100755 tests/data/baseAE/configs/corrupted_model_config.json mode change 100644 => 100755 tests/data/baseAE/configs/generation_config00.json mode change 100644 => 100755 tests/data/baseAE/configs/model_config00.json mode change 100644 => 100755 tests/data/baseAE/configs/not_json_file.md mode change 100644 => 100755 tests/data/baseAE/configs/training_config00.json mode change 100644 => 100755 tests/data/corrupted_config/model_config.json mode change 100644 => 100755 tests/data/custom_architectures.py mode change 100644 => 100755 tests/data/loading/dummy_data_folder/example0.bmp mode change 100644 => 100755 tests/data/loading/dummy_data_folder/example0.jpeg mode change 100644 => 100755 tests/data/loading/dummy_data_folder/example0.jpg mode change 100644 => 100755 tests/data/loading/dummy_data_folder/example0.png mode change 100644 => 100755 tests/data/loading/dummy_data_folder/example0_downsampled_12_12.jpg mode change 100644 => 100755 tests/data/mnist_clean_train_dataset_sample mode change 100644 => 100755 tests/data/rhvae/configs/model_config00.json mode change 100644 => 100755 tests/data/rhvae/configs/trained_model_folder/model.pt mode change 100644 => 100755 tests/data/rhvae/configs/trained_model_folder/model_config.json mode change 100644 => 100755 tests/data/unnormalized_mnist_data_array mode change 100644 => 100755 tests/data/unnormalized_mnist_data_list_of_array mode change 100644 => 100755 tests/pytest.ini mode change 100644 => 100755 tests/test_AE.py mode change 100644 => 100755 tests/test_Adversarial_AE.py mode change 100644 => 100755 tests/test_BetaTCVAE.py mode change 100644 => 100755 tests/test_BetaVAE.py mode change 100644 => 100755 tests/test_CIWAE.py mode change 100644 => 100755 tests/test_DisentangledBetaVAE.py mode change 100644 => 100755 tests/test_FactorVAE.py mode change 100644 => 100755 tests/test_HVAE.py mode change 100644 => 100755 tests/test_IAF.py mode change 100644 => 100755 tests/test_IWAE.py mode change 100644 => 100755 tests/test_MADE.py mode change 100644 => 100755 tests/test_MAF.py mode change 100644 => 100755 tests/test_MIWAE.py mode change 100644 => 100755 tests/test_MSSSIMVAE.py mode change 100644 => 100755 tests/test_PIWAE.py mode change 100644 => 100755 tests/test_PixelCNN.py mode change 100644 => 100755 tests/test_PoincareVAE.py mode change 100644 => 100755 tests/test_RHVAE.py mode change 100644 => 100755 tests/test_SVAE.py mode change 100644 => 100755 tests/test_VAE.py mode change 100644 => 100755 tests/test_VAEGAN.py mode change 100644 => 100755 tests/test_VAE_IAF.py mode change 100644 => 100755 tests/test_VAE_LinFlow.py mode change 100644 => 100755 tests/test_VAMP.py mode change 100644 => 100755 tests/test_VQVAE.py mode change 100644 => 100755 tests/test_WAE_MMD.py mode change 100644 => 100755 tests/test_adversarial_trainer.py mode change 100644 => 100755 tests/test_auto_model.py mode change 100644 => 100755 tests/test_baseAE.py mode change 100644 => 100755 tests/test_baseSampler.py mode change 100644 => 100755 tests/test_base_trainer.py mode change 100644 => 100755 tests/test_config.py mode change 100644 => 100755 tests/test_coupled_optimizers_adversarial_trainer.py mode change 100644 => 100755 tests/test_coupled_optimizers_trainer.py mode change 100644 => 100755 tests/test_datasets.py mode change 100644 => 100755 tests/test_gaussian_mixture_sampler.py mode change 100644 => 100755 tests/test_hypersphere_uniform_sampler.py mode change 100644 => 100755 tests/test_iaf_sampler.py mode change 100644 => 100755 tests/test_info_vae_mmd.py mode change 100644 => 100755 tests/test_maf_sampler.py mode change 100644 => 100755 tests/test_nn_benchmark.py mode change 100644 => 100755 tests/test_normal_sampler.py mode change 100644 => 100755 tests/test_pipeline_standalone.py mode change 100644 => 100755 tests/test_pixelcnn_sampler.py mode change 100644 => 100755 tests/test_planar_flow.py mode change 100644 => 100755 tests/test_preprocessing.py mode change 100644 => 100755 tests/test_pvae_sampler.py mode change 100644 => 100755 tests/test_radial_flow.py mode change 100644 => 100755 tests/test_rae_gp.py mode change 100644 => 100755 tests/test_rae_l2.py mode change 100644 => 100755 tests/test_rhvae_sampler.py mode change 100644 => 100755 tests/test_training_callbacks.py mode change 100644 => 100755 tests/test_two_stage_sampler.py mode change 100644 => 100755 tests/test_vamp_sampler.py mode change 100644 => 100755 tests/your_file.jpeg diff --git a/.coveragerc b/.coveragerc old mode 100644 new mode 100755 diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md old mode 100644 new mode 100755 diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md old mode 100644 new mode 100755 diff --git a/.github/workflows/coverage.yml b/.github/workflows/coverage.yml old mode 100644 new mode 100755 diff --git a/.github/workflows/tests_bench.yml b/.github/workflows/tests_bench.yml old mode 100644 new mode 100755 diff --git a/.gitignore b/.gitignore old mode 100644 new mode 100755 diff --git a/.readthedocs.yml b/.readthedocs.yml old mode 100644 new mode 100755 diff --git a/CITATION.cff b/CITATION.cff old mode 100644 new mode 100755 diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md old mode 100644 new mode 100755 diff --git a/LICENSE b/LICENSE old mode 100644 new mode 100755 diff --git a/MANIFEST.in b/MANIFEST.in old mode 100644 new mode 100755 diff --git a/README.md b/README.md old mode 100644 new mode 100755 diff --git a/codecov.yml b/codecov.yml old mode 100644 new mode 100755 diff --git a/docs/Makefile b/docs/Makefile old mode 100644 new mode 100755 diff --git a/docs/README.md b/docs/README.md old mode 100644 new mode 100755 diff --git a/docs/make.bat b/docs/make.bat old mode 100644 new mode 100755 diff --git a/docs/requirements.txt b/docs/requirements.txt old mode 100644 new mode 100755 diff --git a/docs/source/_static/basic.css b/docs/source/_static/basic.css old mode 100644 new mode 100755 diff --git a/docs/source/_templates/_variables.scss b/docs/source/_templates/_variables.scss old mode 100644 new mode 100755 diff --git a/docs/source/_templates/layout.html b/docs/source/_templates/layout.html old mode 100644 new mode 100755 diff --git a/docs/source/conf.py b/docs/source/conf.py old mode 100644 new mode 100755 diff --git a/docs/source/imgs/CNNs.jpeg b/docs/source/imgs/CNNs.jpeg old mode 100644 new mode 100755 diff --git a/docs/source/imgs/Case_study_1.jpg b/docs/source/imgs/Case_study_1.jpg old mode 100644 new mode 100755 diff --git a/docs/source/imgs/DA_diagram(1).pdf b/docs/source/imgs/DA_diagram(1).pdf old mode 100644 new mode 100755 diff --git a/docs/source/imgs/DA_diagram.png b/docs/source/imgs/DA_diagram.png old mode 100644 new mode 100755 diff --git a/docs/source/imgs/baseline_results.png b/docs/source/imgs/baseline_results.png old mode 100644 new mode 100755 diff --git a/docs/source/imgs/logo_pyraug_2.jpeg b/docs/source/imgs/logo_pyraug_2.jpeg old mode 100644 new mode 100755 diff --git a/docs/source/imgs/nine_digits-rot.pdf b/docs/source/imgs/nine_digits-rot.pdf old mode 100644 new mode 100755 diff --git a/docs/source/imgs/nine_digits-rot.png b/docs/source/imgs/nine_digits-rot.png old mode 100644 new mode 100755 diff --git a/docs/source/imgs/nine_digits.pdf b/docs/source/imgs/nine_digits.pdf old mode 100644 new mode 100755 diff --git a/docs/source/imgs/nine_digits.png b/docs/source/imgs/nine_digits.png old mode 100644 new mode 100755 diff --git a/docs/source/imgs/optimized_results.png b/docs/source/imgs/optimized_results.png old mode 100644 new mode 100755 diff --git a/docs/source/imgs/pyraug_diagram.jpg b/docs/source/imgs/pyraug_diagram.jpg old mode 100644 new mode 100755 diff --git a/docs/source/imgs/pyraug_diagram_simplified.jpg b/docs/source/imgs/pyraug_diagram_simplified.jpg old mode 100644 new mode 100755 diff --git a/docs/source/index.rst b/docs/source/index.rst old mode 100644 new mode 100755 diff --git a/docs/source/models/autoencoders/aae.rst b/docs/source/models/autoencoders/aae.rst old mode 100644 new mode 100755 diff --git a/docs/source/models/autoencoders/ae.rst b/docs/source/models/autoencoders/ae.rst old mode 100644 new mode 100755 diff --git a/docs/source/models/autoencoders/auto_model.rst b/docs/source/models/autoencoders/auto_model.rst old mode 100644 new mode 100755 diff --git a/docs/source/models/autoencoders/baseAE.rst b/docs/source/models/autoencoders/baseAE.rst old mode 100644 new mode 100755 diff --git a/docs/source/models/autoencoders/betatcvae.rst b/docs/source/models/autoencoders/betatcvae.rst old mode 100644 new mode 100755 diff --git a/docs/source/models/autoencoders/betavae.rst b/docs/source/models/autoencoders/betavae.rst old mode 100644 new mode 100755 diff --git a/docs/source/models/autoencoders/ciwae.rst b/docs/source/models/autoencoders/ciwae.rst old mode 100644 new mode 100755 diff --git a/docs/source/models/autoencoders/disentangled_betavae.rst b/docs/source/models/autoencoders/disentangled_betavae.rst old mode 100644 new mode 100755 diff --git a/docs/source/models/autoencoders/factorvae.rst b/docs/source/models/autoencoders/factorvae.rst old mode 100644 new mode 100755 diff --git a/docs/source/models/autoencoders/hvae.rst b/docs/source/models/autoencoders/hvae.rst old mode 100644 new mode 100755 diff --git a/docs/source/models/autoencoders/infovae.rst b/docs/source/models/autoencoders/infovae.rst old mode 100644 new mode 100755 diff --git a/docs/source/models/autoencoders/iwae.rst b/docs/source/models/autoencoders/iwae.rst old mode 100644 new mode 100755 diff --git a/docs/source/models/autoencoders/miwae.rst b/docs/source/models/autoencoders/miwae.rst old mode 100644 new mode 100755 diff --git a/docs/source/models/autoencoders/models.rst b/docs/source/models/autoencoders/models.rst old mode 100644 new mode 100755 diff --git a/docs/source/models/autoencoders/msssimvae.rst b/docs/source/models/autoencoders/msssimvae.rst old mode 100644 new mode 100755 diff --git a/docs/source/models/autoencoders/piwae.rst b/docs/source/models/autoencoders/piwae.rst old mode 100644 new mode 100755 diff --git a/docs/source/models/autoencoders/pvae.rst b/docs/source/models/autoencoders/pvae.rst old mode 100644 new mode 100755 diff --git a/docs/source/models/autoencoders/rae_gp.rst b/docs/source/models/autoencoders/rae_gp.rst old mode 100644 new mode 100755 diff --git a/docs/source/models/autoencoders/rae_l2.rst b/docs/source/models/autoencoders/rae_l2.rst old mode 100644 new mode 100755 diff --git a/docs/source/models/autoencoders/rhvae.rst b/docs/source/models/autoencoders/rhvae.rst old mode 100644 new mode 100755 diff --git a/docs/source/models/autoencoders/svae.rst b/docs/source/models/autoencoders/svae.rst old mode 100644 new mode 100755 diff --git a/docs/source/models/autoencoders/vae.rst b/docs/source/models/autoencoders/vae.rst old mode 100644 new mode 100755 diff --git a/docs/source/models/autoencoders/vae_iaf.rst b/docs/source/models/autoencoders/vae_iaf.rst old mode 100644 new mode 100755 diff --git a/docs/source/models/autoencoders/vae_lin_nf.rst b/docs/source/models/autoencoders/vae_lin_nf.rst old mode 100644 new mode 100755 diff --git a/docs/source/models/autoencoders/vaegan.rst b/docs/source/models/autoencoders/vaegan.rst old mode 100644 new mode 100755 diff --git a/docs/source/models/autoencoders/vamp.rst b/docs/source/models/autoencoders/vamp.rst old mode 100644 new mode 100755 diff --git a/docs/source/models/autoencoders/vqvae.rst b/docs/source/models/autoencoders/vqvae.rst old mode 100644 new mode 100755 diff --git a/docs/source/models/autoencoders/wae.rst b/docs/source/models/autoencoders/wae.rst old mode 100644 new mode 100755 diff --git a/docs/source/models/nn/celeba/convnets.rst b/docs/source/models/nn/celeba/convnets.rst old mode 100644 new mode 100755 diff --git a/docs/source/models/nn/celeba/pythae_benchmarks_nn_celeba.rst b/docs/source/models/nn/celeba/pythae_benchmarks_nn_celeba.rst old mode 100644 new mode 100755 diff --git a/docs/source/models/nn/celeba/resnets.rst b/docs/source/models/nn/celeba/resnets.rst old mode 100644 new mode 100755 diff --git a/docs/source/models/nn/cifar/convnets.rst b/docs/source/models/nn/cifar/convnets.rst old mode 100644 new mode 100755 diff --git a/docs/source/models/nn/cifar/pythae_benchmarks_nn_cifar.rst b/docs/source/models/nn/cifar/pythae_benchmarks_nn_cifar.rst old mode 100644 new mode 100755 diff --git a/docs/source/models/nn/cifar/resnets.rst b/docs/source/models/nn/cifar/resnets.rst old mode 100644 new mode 100755 diff --git a/docs/source/models/nn/mnist/convnets.rst b/docs/source/models/nn/mnist/convnets.rst old mode 100644 new mode 100755 diff --git a/docs/source/models/nn/mnist/pythae_benchmarks_nn_mnist.rst b/docs/source/models/nn/mnist/pythae_benchmarks_nn_mnist.rst old mode 100644 new mode 100755 diff --git a/docs/source/models/nn/mnist/resnets.rst b/docs/source/models/nn/mnist/resnets.rst old mode 100644 new mode 100755 diff --git a/docs/source/models/nn/nn.rst b/docs/source/models/nn/nn.rst old mode 100644 new mode 100755 diff --git a/docs/source/models/nn/pythae_base_nn.rst b/docs/source/models/nn/pythae_base_nn.rst old mode 100644 new mode 100755 diff --git a/docs/source/models/nn/pythae_benchmarks_nn.rst b/docs/source/models/nn/pythae_benchmarks_nn.rst old mode 100644 new mode 100755 diff --git a/docs/source/models/normalizing_flows/basenf.rst b/docs/source/models/normalizing_flows/basenf.rst old mode 100644 new mode 100755 diff --git a/docs/source/models/normalizing_flows/iaf.rst b/docs/source/models/normalizing_flows/iaf.rst old mode 100644 new mode 100755 diff --git a/docs/source/models/normalizing_flows/made.rst b/docs/source/models/normalizing_flows/made.rst old mode 100644 new mode 100755 diff --git a/docs/source/models/normalizing_flows/maf.rst b/docs/source/models/normalizing_flows/maf.rst old mode 100644 new mode 100755 diff --git a/docs/source/models/normalizing_flows/normalizing_flows.rst b/docs/source/models/normalizing_flows/normalizing_flows.rst old mode 100644 new mode 100755 diff --git a/docs/source/models/normalizing_flows/pixelcnn.rst b/docs/source/models/normalizing_flows/pixelcnn.rst old mode 100644 new mode 100755 diff --git a/docs/source/models/normalizing_flows/planar_flow.rst b/docs/source/models/normalizing_flows/planar_flow.rst old mode 100644 new mode 100755 diff --git a/docs/source/models/normalizing_flows/radial_flow.rst b/docs/source/models/normalizing_flows/radial_flow.rst old mode 100644 new mode 100755 diff --git a/docs/source/models/pythae.models.rst b/docs/source/models/pythae.models.rst old mode 100644 new mode 100755 diff --git a/docs/source/pipelines/generation.rst b/docs/source/pipelines/generation.rst old mode 100644 new mode 100755 diff --git a/docs/source/pipelines/pythae.pipelines.rst b/docs/source/pipelines/pythae.pipelines.rst old mode 100644 new mode 100755 diff --git a/docs/source/pipelines/training.rst b/docs/source/pipelines/training.rst old mode 100644 new mode 100755 diff --git a/docs/source/references.bib b/docs/source/references.bib old mode 100644 new mode 100755 diff --git a/docs/source/samplers/basesampler.rst b/docs/source/samplers/basesampler.rst old mode 100644 new mode 100755 diff --git a/docs/source/samplers/gmm_sampler.rst b/docs/source/samplers/gmm_sampler.rst old mode 100644 new mode 100755 diff --git a/docs/source/samplers/iaf_sampler.rst b/docs/source/samplers/iaf_sampler.rst old mode 100644 new mode 100755 diff --git a/docs/source/samplers/maf_sampler.rst b/docs/source/samplers/maf_sampler.rst old mode 100644 new mode 100755 diff --git a/docs/source/samplers/normal_sampler.rst b/docs/source/samplers/normal_sampler.rst old mode 100644 new mode 100755 diff --git a/docs/source/samplers/pixelcnn_sampler.rst b/docs/source/samplers/pixelcnn_sampler.rst old mode 100644 new mode 100755 diff --git a/docs/source/samplers/poincare_disk_sampler.rst b/docs/source/samplers/poincare_disk_sampler.rst old mode 100644 new mode 100755 diff --git a/docs/source/samplers/pythae.samplers.rst b/docs/source/samplers/pythae.samplers.rst old mode 100644 new mode 100755 diff --git a/docs/source/samplers/rhvae_sampler.rst b/docs/source/samplers/rhvae_sampler.rst old mode 100644 new mode 100755 diff --git a/docs/source/samplers/twostage_sampler.rst b/docs/source/samplers/twostage_sampler.rst old mode 100644 new mode 100755 diff --git a/docs/source/samplers/unit_sphere_unif_sampler.rst b/docs/source/samplers/unit_sphere_unif_sampler.rst old mode 100644 new mode 100755 diff --git a/docs/source/samplers/vamp_sampler.rst b/docs/source/samplers/vamp_sampler.rst old mode 100644 new mode 100755 diff --git a/docs/source/trainers/pythae.trainer.rst b/docs/source/trainers/pythae.trainer.rst old mode 100644 new mode 100755 diff --git a/docs/source/trainers/pythae.trainers.adversarial_trainer.rst b/docs/source/trainers/pythae.trainers.adversarial_trainer.rst old mode 100644 new mode 100755 diff --git a/docs/source/trainers/pythae.trainers.base_trainer.rst b/docs/source/trainers/pythae.trainers.base_trainer.rst old mode 100644 new mode 100755 diff --git a/docs/source/trainers/pythae.trainers.coupled_optimizer_adversarial_trainer.rst b/docs/source/trainers/pythae.trainers.coupled_optimizer_adversarial_trainer.rst old mode 100644 new mode 100755 diff --git a/docs/source/trainers/pythae.trainers.coupled_trainer.rst b/docs/source/trainers/pythae.trainers.coupled_trainer.rst old mode 100644 new mode 100755 diff --git a/docs/source/trainers/pythae.training_callbacks.rst b/docs/source/trainers/pythae.training_callbacks.rst old mode 100644 new mode 100755 diff --git a/examples/notebooks/README.md b/examples/notebooks/README.md old mode 100644 new mode 100755 diff --git a/examples/notebooks/comet_experiment_monitoring.ipynb b/examples/notebooks/comet_experiment_monitoring.ipynb old mode 100644 new mode 100755 diff --git a/examples/notebooks/custom_dataset.ipynb b/examples/notebooks/custom_dataset.ipynb old mode 100644 new mode 100755 diff --git a/examples/notebooks/hf_hub_models_sharing.ipynb b/examples/notebooks/hf_hub_models_sharing.ipynb old mode 100644 new mode 100755 diff --git a/examples/notebooks/making_your_own_autoencoder.ipynb b/examples/notebooks/making_your_own_autoencoder.ipynb old mode 100644 new mode 100755 diff --git a/examples/notebooks/mlflow_experiment_monitoring.ipynb b/examples/notebooks/mlflow_experiment_monitoring.ipynb old mode 100644 new mode 100755 diff --git a/examples/notebooks/models_training/adversarial_ae_training.ipynb b/examples/notebooks/models_training/adversarial_ae_training.ipynb old mode 100644 new mode 100755 diff --git a/examples/notebooks/models_training/ae_training.ipynb b/examples/notebooks/models_training/ae_training.ipynb old mode 100644 new mode 100755 diff --git a/examples/notebooks/models_training/beta_tc_vae_training.ipynb b/examples/notebooks/models_training/beta_tc_vae_training.ipynb old mode 100644 new mode 100755 diff --git a/examples/notebooks/models_training/beta_vae_training.ipynb b/examples/notebooks/models_training/beta_vae_training.ipynb old mode 100644 new mode 100755 diff --git a/examples/notebooks/models_training/ciwae_training.ipynb b/examples/notebooks/models_training/ciwae_training.ipynb old mode 100644 new mode 100755 diff --git a/examples/notebooks/models_training/disentangled_beta_vae_training.ipynb b/examples/notebooks/models_training/disentangled_beta_vae_training.ipynb old mode 100644 new mode 100755 diff --git a/examples/notebooks/models_training/factor_vae_training.ipynb b/examples/notebooks/models_training/factor_vae_training.ipynb old mode 100644 new mode 100755 diff --git a/examples/notebooks/models_training/hvae_training.ipynb b/examples/notebooks/models_training/hvae_training.ipynb old mode 100644 new mode 100755 diff --git a/examples/notebooks/models_training/info_vae_training.ipynb b/examples/notebooks/models_training/info_vae_training.ipynb old mode 100644 new mode 100755 diff --git a/examples/notebooks/models_training/iwae_training.ipynb b/examples/notebooks/models_training/iwae_training.ipynb old mode 100644 new mode 100755 diff --git a/examples/notebooks/models_training/miwae_training.ipynb b/examples/notebooks/models_training/miwae_training.ipynb old mode 100644 new mode 100755 diff --git a/examples/notebooks/models_training/ms_ssim_vae_training.ipynb b/examples/notebooks/models_training/ms_ssim_vae_training.ipynb old mode 100644 new mode 100755 diff --git a/examples/notebooks/models_training/normalizing_flows_training.ipynb b/examples/notebooks/models_training/normalizing_flows_training.ipynb old mode 100644 new mode 100755 diff --git a/examples/notebooks/models_training/piwae_training.ipynb b/examples/notebooks/models_training/piwae_training.ipynb old mode 100644 new mode 100755 diff --git a/examples/notebooks/models_training/pvae_training.ipynb b/examples/notebooks/models_training/pvae_training.ipynb old mode 100644 new mode 100755 diff --git a/examples/notebooks/models_training/rae_gp_training.ipynb b/examples/notebooks/models_training/rae_gp_training.ipynb old mode 100644 new mode 100755 diff --git a/examples/notebooks/models_training/rae_l2_training.ipynb b/examples/notebooks/models_training/rae_l2_training.ipynb old mode 100644 new mode 100755 diff --git a/examples/notebooks/models_training/rhvae_training.ipynb b/examples/notebooks/models_training/rhvae_training.ipynb old mode 100644 new mode 100755 diff --git a/examples/notebooks/models_training/svae_training.ipynb b/examples/notebooks/models_training/svae_training.ipynb old mode 100644 new mode 100755 diff --git a/examples/notebooks/models_training/vae_iaf_training.ipynb b/examples/notebooks/models_training/vae_iaf_training.ipynb old mode 100644 new mode 100755 diff --git a/examples/notebooks/models_training/vae_lin_nf_training.ipynb b/examples/notebooks/models_training/vae_lin_nf_training.ipynb old mode 100644 new mode 100755 diff --git a/examples/notebooks/models_training/vae_lstm.ipynb b/examples/notebooks/models_training/vae_lstm.ipynb old mode 100644 new mode 100755 diff --git a/examples/notebooks/models_training/vae_training.ipynb b/examples/notebooks/models_training/vae_training.ipynb old mode 100644 new mode 100755 diff --git a/examples/notebooks/models_training/vaegan_training.ipynb b/examples/notebooks/models_training/vaegan_training.ipynb old mode 100644 new mode 100755 diff --git a/examples/notebooks/models_training/vamp_training.ipynb b/examples/notebooks/models_training/vamp_training.ipynb old mode 100644 new mode 100755 diff --git a/examples/notebooks/models_training/vqvae_training.ipynb b/examples/notebooks/models_training/vqvae_training.ipynb old mode 100644 new mode 100755 diff --git a/examples/notebooks/models_training/wae_training.ipynb b/examples/notebooks/models_training/wae_training.ipynb old mode 100644 new mode 100755 diff --git a/examples/notebooks/overview_notebook.ipynb b/examples/notebooks/overview_notebook.ipynb old mode 100644 new mode 100755 diff --git a/examples/notebooks/requirements.txt b/examples/notebooks/requirements.txt old mode 100644 new mode 100755 diff --git a/examples/notebooks/wandb_experiment_monitoring.ipynb b/examples/notebooks/wandb_experiment_monitoring.ipynb old mode 100644 new mode 100755 diff --git a/examples/scripts/README.md b/examples/scripts/README.md old mode 100644 new mode 100755 diff --git a/examples/scripts/configs/binary_mnist/base_training_config.json b/examples/scripts/configs/binary_mnist/base_training_config.json old mode 100644 new mode 100755 diff --git a/examples/scripts/configs/binary_mnist/hvae_config.json b/examples/scripts/configs/binary_mnist/hvae_config.json old mode 100644 new mode 100755 diff --git a/examples/scripts/configs/binary_mnist/iwae_config.json b/examples/scripts/configs/binary_mnist/iwae_config.json old mode 100644 new mode 100755 diff --git a/examples/scripts/configs/binary_mnist/rhvae_config.json b/examples/scripts/configs/binary_mnist/rhvae_config.json old mode 100644 new mode 100755 diff --git a/examples/scripts/configs/binary_mnist/svae_config.json b/examples/scripts/configs/binary_mnist/svae_config.json old mode 100644 new mode 100755 diff --git a/examples/scripts/configs/binary_mnist/vae_config.json b/examples/scripts/configs/binary_mnist/vae_config.json old mode 100644 new mode 100755 diff --git a/examples/scripts/configs/binary_mnist/vae_iaf_config.json b/examples/scripts/configs/binary_mnist/vae_iaf_config.json old mode 100644 new mode 100755 diff --git a/examples/scripts/configs/binary_mnist/vae_lin_nf_config.json b/examples/scripts/configs/binary_mnist/vae_lin_nf_config.json old mode 100644 new mode 100755 diff --git a/examples/scripts/configs/binary_mnist/vamp_config.json b/examples/scripts/configs/binary_mnist/vamp_config.json old mode 100644 new mode 100755 diff --git a/examples/scripts/configs/celeba/aae_config.json b/examples/scripts/configs/celeba/aae_config.json old mode 100644 new mode 100755 diff --git a/examples/scripts/configs/celeba/ae_config.json b/examples/scripts/configs/celeba/ae_config.json old mode 100644 new mode 100755 diff --git a/examples/scripts/configs/celeba/base_training_config.json b/examples/scripts/configs/celeba/base_training_config.json old mode 100644 new mode 100755 diff --git a/examples/scripts/configs/celeba/beta_tc_vae_config.json b/examples/scripts/configs/celeba/beta_tc_vae_config.json old mode 100644 new mode 100755 diff --git a/examples/scripts/configs/celeba/beta_vae_config.json b/examples/scripts/configs/celeba/beta_vae_config.json old mode 100644 new mode 100755 diff --git a/examples/scripts/configs/celeba/disentangled_beta_vae_config.json b/examples/scripts/configs/celeba/disentangled_beta_vae_config.json old mode 100644 new mode 100755 diff --git a/examples/scripts/configs/celeba/factor_vae_config.json b/examples/scripts/configs/celeba/factor_vae_config.json old mode 100644 new mode 100755 diff --git a/examples/scripts/configs/celeba/hvae_config.json b/examples/scripts/configs/celeba/hvae_config.json old mode 100644 new mode 100755 diff --git a/examples/scripts/configs/celeba/info_vae_config.json b/examples/scripts/configs/celeba/info_vae_config.json old mode 100644 new mode 100755 diff --git a/examples/scripts/configs/celeba/msssim_vae_config.json b/examples/scripts/configs/celeba/msssim_vae_config.json old mode 100644 new mode 100755 diff --git a/examples/scripts/configs/celeba/rae_gp_config.json b/examples/scripts/configs/celeba/rae_gp_config.json old mode 100644 new mode 100755 diff --git a/examples/scripts/configs/celeba/rae_l2_config.json b/examples/scripts/configs/celeba/rae_l2_config.json old mode 100644 new mode 100755 diff --git a/examples/scripts/configs/celeba/rhvae_config.json b/examples/scripts/configs/celeba/rhvae_config.json old mode 100644 new mode 100755 diff --git a/examples/scripts/configs/celeba/svae_config.json b/examples/scripts/configs/celeba/svae_config.json old mode 100644 new mode 100755 diff --git a/examples/scripts/configs/celeba/vae_config.json b/examples/scripts/configs/celeba/vae_config.json old mode 100644 new mode 100755 diff --git a/examples/scripts/configs/celeba/vae_iaf_config.json b/examples/scripts/configs/celeba/vae_iaf_config.json old mode 100644 new mode 100755 diff --git a/examples/scripts/configs/celeba/vae_lin_nf_config.json b/examples/scripts/configs/celeba/vae_lin_nf_config.json old mode 100644 new mode 100755 diff --git a/examples/scripts/configs/celeba/vaegan_config.json b/examples/scripts/configs/celeba/vaegan_config.json old mode 100644 new mode 100755 diff --git a/examples/scripts/configs/celeba/vamp_config.json b/examples/scripts/configs/celeba/vamp_config.json old mode 100644 new mode 100755 diff --git a/examples/scripts/configs/celeba/vqvae_config.json b/examples/scripts/configs/celeba/vqvae_config.json old mode 100644 new mode 100755 diff --git a/examples/scripts/configs/celeba/wae_config.json b/examples/scripts/configs/celeba/wae_config.json old mode 100644 new mode 100755 diff --git a/examples/scripts/configs/cifar10/ae_config.json b/examples/scripts/configs/cifar10/ae_config.json old mode 100644 new mode 100755 diff --git a/examples/scripts/configs/cifar10/base_training_config.json b/examples/scripts/configs/cifar10/base_training_config.json old mode 100644 new mode 100755 diff --git a/examples/scripts/configs/cifar10/beta_vae_config.json b/examples/scripts/configs/cifar10/beta_vae_config.json old mode 100644 new mode 100755 diff --git a/examples/scripts/configs/cifar10/hvae_config.json b/examples/scripts/configs/cifar10/hvae_config.json old mode 100644 new mode 100755 diff --git a/examples/scripts/configs/cifar10/info_vae_config.json b/examples/scripts/configs/cifar10/info_vae_config.json old mode 100644 new mode 100755 diff --git a/examples/scripts/configs/cifar10/rae_gp_config.json b/examples/scripts/configs/cifar10/rae_gp_config.json old mode 100644 new mode 100755 diff --git a/examples/scripts/configs/cifar10/rae_l2_config.json b/examples/scripts/configs/cifar10/rae_l2_config.json old mode 100644 new mode 100755 diff --git a/examples/scripts/configs/cifar10/rhvae_config.json b/examples/scripts/configs/cifar10/rhvae_config.json old mode 100644 new mode 100755 diff --git a/examples/scripts/configs/cifar10/vae_config.json b/examples/scripts/configs/cifar10/vae_config.json old mode 100644 new mode 100755 diff --git a/examples/scripts/configs/cifar10/vamp_config.json b/examples/scripts/configs/cifar10/vamp_config.json old mode 100644 new mode 100755 diff --git a/examples/scripts/configs/cifar10/vqvae_config.json b/examples/scripts/configs/cifar10/vqvae_config.json old mode 100644 new mode 100755 diff --git a/examples/scripts/configs/cifar10/wae_config.json b/examples/scripts/configs/cifar10/wae_config.json old mode 100644 new mode 100755 diff --git a/examples/scripts/configs/dsprites/beta_tc_vae/base_training_config.json b/examples/scripts/configs/dsprites/beta_tc_vae/base_training_config.json old mode 100644 new mode 100755 diff --git a/examples/scripts/configs/dsprites/beta_tc_vae/beta_tc_vae_config.json b/examples/scripts/configs/dsprites/beta_tc_vae/beta_tc_vae_config.json old mode 100644 new mode 100755 diff --git a/examples/scripts/configs/dsprites/factorvae/base_training_config.json b/examples/scripts/configs/dsprites/factorvae/base_training_config.json old mode 100644 new mode 100755 diff --git a/examples/scripts/configs/dsprites/factorvae/factorvae_config.json b/examples/scripts/configs/dsprites/factorvae/factorvae_config.json old mode 100644 new mode 100755 diff --git a/examples/scripts/configs/mnist/aae_config.json b/examples/scripts/configs/mnist/aae_config.json old mode 100644 new mode 100755 diff --git a/examples/scripts/configs/mnist/ae_config.json b/examples/scripts/configs/mnist/ae_config.json old mode 100644 new mode 100755 diff --git a/examples/scripts/configs/mnist/base_training_config.json b/examples/scripts/configs/mnist/base_training_config.json old mode 100644 new mode 100755 diff --git a/examples/scripts/configs/mnist/beta_tc_vae_config.json b/examples/scripts/configs/mnist/beta_tc_vae_config.json old mode 100644 new mode 100755 diff --git a/examples/scripts/configs/mnist/beta_vae_config.json b/examples/scripts/configs/mnist/beta_vae_config.json old mode 100644 new mode 100755 diff --git a/examples/scripts/configs/mnist/disentangled_beta_vae_config.json b/examples/scripts/configs/mnist/disentangled_beta_vae_config.json old mode 100644 new mode 100755 diff --git a/examples/scripts/configs/mnist/factor_vae_config.json b/examples/scripts/configs/mnist/factor_vae_config.json old mode 100644 new mode 100755 diff --git a/examples/scripts/configs/mnist/hvae_config.json b/examples/scripts/configs/mnist/hvae_config.json old mode 100644 new mode 100755 diff --git a/examples/scripts/configs/mnist/info_vae_config.json b/examples/scripts/configs/mnist/info_vae_config.json old mode 100644 new mode 100755 diff --git a/examples/scripts/configs/mnist/iwae_config.json b/examples/scripts/configs/mnist/iwae_config.json old mode 100644 new mode 100755 diff --git a/examples/scripts/configs/mnist/msssim_vae_config.json b/examples/scripts/configs/mnist/msssim_vae_config.json old mode 100644 new mode 100755 diff --git a/examples/scripts/configs/mnist/rae_gp_config.json b/examples/scripts/configs/mnist/rae_gp_config.json old mode 100644 new mode 100755 diff --git a/examples/scripts/configs/mnist/rae_l2_config.json b/examples/scripts/configs/mnist/rae_l2_config.json old mode 100644 new mode 100755 diff --git a/examples/scripts/configs/mnist/rhvae_config.json b/examples/scripts/configs/mnist/rhvae_config.json old mode 100644 new mode 100755 diff --git a/examples/scripts/configs/mnist/svae_config.json b/examples/scripts/configs/mnist/svae_config.json old mode 100644 new mode 100755 diff --git a/examples/scripts/configs/mnist/vae_config.json b/examples/scripts/configs/mnist/vae_config.json old mode 100644 new mode 100755 diff --git a/examples/scripts/configs/mnist/vae_iaf_config.json b/examples/scripts/configs/mnist/vae_iaf_config.json old mode 100644 new mode 100755 diff --git a/examples/scripts/configs/mnist/vae_lin_nf_config.json b/examples/scripts/configs/mnist/vae_lin_nf_config.json old mode 100644 new mode 100755 diff --git a/examples/scripts/configs/mnist/vaegan_config.json b/examples/scripts/configs/mnist/vaegan_config.json old mode 100644 new mode 100755 diff --git a/examples/scripts/configs/mnist/vamp_config.json b/examples/scripts/configs/mnist/vamp_config.json old mode 100644 new mode 100755 diff --git a/examples/scripts/configs/mnist/vqvae_config.json b/examples/scripts/configs/mnist/vqvae_config.json old mode 100644 new mode 100755 diff --git a/examples/scripts/configs/mnist/wae_config.json b/examples/scripts/configs/mnist/wae_config.json old mode 100644 new mode 100755 diff --git a/examples/scripts/custom_nn.py b/examples/scripts/custom_nn.py old mode 100644 new mode 100755 diff --git a/examples/scripts/distributed_training_ffhq.py b/examples/scripts/distributed_training_ffhq.py old mode 100644 new mode 100755 diff --git a/examples/scripts/distributed_training_imagenet.py b/examples/scripts/distributed_training_imagenet.py old mode 100644 new mode 100755 diff --git a/examples/scripts/distributed_training_mnist.py b/examples/scripts/distributed_training_mnist.py old mode 100644 new mode 100755 diff --git a/examples/scripts/reproducibility/README.md b/examples/scripts/reproducibility/README.md old mode 100644 new mode 100755 diff --git a/examples/scripts/reproducibility/aae.py b/examples/scripts/reproducibility/aae.py old mode 100644 new mode 100755 diff --git a/examples/scripts/reproducibility/betatcvae.py b/examples/scripts/reproducibility/betatcvae.py old mode 100644 new mode 100755 diff --git a/examples/scripts/reproducibility/ciwae.py b/examples/scripts/reproducibility/ciwae.py old mode 100644 new mode 100755 diff --git a/examples/scripts/reproducibility/hvae.py b/examples/scripts/reproducibility/hvae.py old mode 100644 new mode 100755 diff --git a/examples/scripts/reproducibility/iwae.py b/examples/scripts/reproducibility/iwae.py old mode 100644 new mode 100755 diff --git a/examples/scripts/reproducibility/miwae.py b/examples/scripts/reproducibility/miwae.py old mode 100644 new mode 100755 diff --git a/examples/scripts/reproducibility/piwae.py b/examples/scripts/reproducibility/piwae.py old mode 100644 new mode 100755 diff --git a/examples/scripts/reproducibility/pvae.py b/examples/scripts/reproducibility/pvae.py old mode 100644 new mode 100755 diff --git a/examples/scripts/reproducibility/rae_gp.py b/examples/scripts/reproducibility/rae_gp.py old mode 100644 new mode 100755 diff --git a/examples/scripts/reproducibility/rae_l2.py b/examples/scripts/reproducibility/rae_l2.py old mode 100644 new mode 100755 diff --git a/examples/scripts/reproducibility/svae.py b/examples/scripts/reproducibility/svae.py old mode 100644 new mode 100755 diff --git a/examples/scripts/reproducibility/vae.py b/examples/scripts/reproducibility/vae.py old mode 100644 new mode 100755 diff --git a/examples/scripts/reproducibility/vamp.py b/examples/scripts/reproducibility/vamp.py old mode 100644 new mode 100755 diff --git a/examples/scripts/reproducibility/wae.py b/examples/scripts/reproducibility/wae.py old mode 100644 new mode 100755 diff --git a/examples/scripts/training.py b/examples/scripts/training.py old mode 100644 new mode 100755 diff --git a/examples/showcases/aae_normal_sampling_celeba.png b/examples/showcases/aae_normal_sampling_celeba.png old mode 100644 new mode 100755 diff --git a/examples/showcases/aae_normal_sampling_mnist.png b/examples/showcases/aae_normal_sampling_mnist.png old mode 100644 new mode 100755 diff --git a/examples/showcases/aae_reconstruction_celeba.png b/examples/showcases/aae_reconstruction_celeba.png old mode 100644 new mode 100755 diff --git a/examples/showcases/aae_reconstruction_mnist.png b/examples/showcases/aae_reconstruction_mnist.png old mode 100644 new mode 100755 diff --git a/examples/showcases/ae_gmm_sampling_celeba.png b/examples/showcases/ae_gmm_sampling_celeba.png old mode 100644 new mode 100755 diff --git a/examples/showcases/ae_gmm_sampling_mnist.png b/examples/showcases/ae_gmm_sampling_mnist.png old mode 100644 new mode 100755 diff --git a/examples/showcases/ae_normal_sampling_celeba.png b/examples/showcases/ae_normal_sampling_celeba.png old mode 100644 new mode 100755 diff --git a/examples/showcases/ae_normal_sampling_mnist.png b/examples/showcases/ae_normal_sampling_mnist.png old mode 100644 new mode 100755 diff --git a/examples/showcases/ae_reconstruction_celeba.png b/examples/showcases/ae_reconstruction_celeba.png old mode 100644 new mode 100755 diff --git a/examples/showcases/ae_reconstruction_mnist.png b/examples/showcases/ae_reconstruction_mnist.png old mode 100644 new mode 100755 diff --git a/examples/showcases/beta_tc_vae_normal_sampling_celeba.png b/examples/showcases/beta_tc_vae_normal_sampling_celeba.png old mode 100644 new mode 100755 diff --git a/examples/showcases/beta_tc_vae_normal_sampling_mnist.png b/examples/showcases/beta_tc_vae_normal_sampling_mnist.png old mode 100644 new mode 100755 diff --git a/examples/showcases/beta_tc_vae_reconstruction_celeba.png b/examples/showcases/beta_tc_vae_reconstruction_celeba.png old mode 100644 new mode 100755 diff --git a/examples/showcases/beta_tc_vae_reconstruction_mnist.png b/examples/showcases/beta_tc_vae_reconstruction_mnist.png old mode 100644 new mode 100755 diff --git a/examples/showcases/beta_vae_normal_sampling_celeba.png b/examples/showcases/beta_vae_normal_sampling_celeba.png old mode 100644 new mode 100755 diff --git a/examples/showcases/beta_vae_normal_sampling_mnist.png b/examples/showcases/beta_vae_normal_sampling_mnist.png old mode 100644 new mode 100755 diff --git a/examples/showcases/beta_vae_reconstruction_celeba.png b/examples/showcases/beta_vae_reconstruction_celeba.png old mode 100644 new mode 100755 diff --git a/examples/showcases/beta_vae_reconstruction_mnist.png b/examples/showcases/beta_vae_reconstruction_mnist.png old mode 100644 new mode 100755 diff --git a/examples/showcases/disentangled_beta_vae_normal_sampling_celeba.png b/examples/showcases/disentangled_beta_vae_normal_sampling_celeba.png old mode 100644 new mode 100755 diff --git a/examples/showcases/disentangled_beta_vae_normal_sampling_mnist.png b/examples/showcases/disentangled_beta_vae_normal_sampling_mnist.png old mode 100644 new mode 100755 diff --git a/examples/showcases/disentangled_beta_vae_reconstruction_celeba.png b/examples/showcases/disentangled_beta_vae_reconstruction_celeba.png old mode 100644 new mode 100755 diff --git a/examples/showcases/disentangled_beta_vae_reconstruction_mnist.png b/examples/showcases/disentangled_beta_vae_reconstruction_mnist.png old mode 100644 new mode 100755 diff --git a/examples/showcases/eval_reconstruction_celeba.png b/examples/showcases/eval_reconstruction_celeba.png old mode 100644 new mode 100755 diff --git a/examples/showcases/eval_reconstruction_mnist.png b/examples/showcases/eval_reconstruction_mnist.png old mode 100644 new mode 100755 diff --git a/examples/showcases/factor_vae_normal_sampling_celeba.png b/examples/showcases/factor_vae_normal_sampling_celeba.png old mode 100644 new mode 100755 diff --git a/examples/showcases/factor_vae_normal_sampling_mnist.png b/examples/showcases/factor_vae_normal_sampling_mnist.png old mode 100644 new mode 100755 diff --git a/examples/showcases/factor_vae_reconstruction_celeba.png b/examples/showcases/factor_vae_reconstruction_celeba.png old mode 100644 new mode 100755 diff --git a/examples/showcases/factor_vae_reconstruction_mnist.png b/examples/showcases/factor_vae_reconstruction_mnist.png old mode 100644 new mode 100755 diff --git a/examples/showcases/hvae_gmm_sampling_mnist.png b/examples/showcases/hvae_gmm_sampling_mnist.png old mode 100644 new mode 100755 diff --git a/examples/showcases/hvae_normal_sampling_celeba.png b/examples/showcases/hvae_normal_sampling_celeba.png old mode 100644 new mode 100755 diff --git a/examples/showcases/hvae_normal_sampling_mnist.png b/examples/showcases/hvae_normal_sampling_mnist.png old mode 100644 new mode 100755 diff --git a/examples/showcases/hvae_reconstruction_celeba.png b/examples/showcases/hvae_reconstruction_celeba.png old mode 100644 new mode 100755 diff --git a/examples/showcases/hvae_reconstruction_mnist.png b/examples/showcases/hvae_reconstruction_mnist.png old mode 100644 new mode 100755 diff --git a/examples/showcases/hvvae_gmm_sampling_mnist.png b/examples/showcases/hvvae_gmm_sampling_mnist.png old mode 100644 new mode 100755 diff --git a/examples/showcases/hvvae_normal_sampling_mnist.png b/examples/showcases/hvvae_normal_sampling_mnist.png old mode 100644 new mode 100755 diff --git a/examples/showcases/infovae_normal_sampling_celeba.png b/examples/showcases/infovae_normal_sampling_celeba.png old mode 100644 new mode 100755 diff --git a/examples/showcases/infovae_normal_sampling_mnist.png b/examples/showcases/infovae_normal_sampling_mnist.png old mode 100644 new mode 100755 diff --git a/examples/showcases/infovae_reconstruction_celeba.png b/examples/showcases/infovae_reconstruction_celeba.png old mode 100644 new mode 100755 diff --git a/examples/showcases/infovae_reconstruction_mnist.png b/examples/showcases/infovae_reconstruction_mnist.png old mode 100644 new mode 100755 diff --git a/examples/showcases/iwae_gmm_sampling_mnist.png b/examples/showcases/iwae_gmm_sampling_mnist.png old mode 100644 new mode 100755 diff --git a/examples/showcases/iwae_normal_sampling_celeba.png b/examples/showcases/iwae_normal_sampling_celeba.png old mode 100644 new mode 100755 diff --git a/examples/showcases/iwae_normal_sampling_mnist.png b/examples/showcases/iwae_normal_sampling_mnist.png old mode 100644 new mode 100755 diff --git a/examples/showcases/iwae_reconstruction_celeba.png b/examples/showcases/iwae_reconstruction_celeba.png old mode 100644 new mode 100755 diff --git a/examples/showcases/iwae_reconstruction_mnist.png b/examples/showcases/iwae_reconstruction_mnist.png old mode 100644 new mode 100755 diff --git a/examples/showcases/msssim_vae_normal_sampling_celeba.png b/examples/showcases/msssim_vae_normal_sampling_celeba.png old mode 100644 new mode 100755 diff --git a/examples/showcases/msssim_vae_normal_sampling_mnist.png b/examples/showcases/msssim_vae_normal_sampling_mnist.png old mode 100644 new mode 100755 diff --git a/examples/showcases/msssim_vae_reconstruction_celeba.png b/examples/showcases/msssim_vae_reconstruction_celeba.png old mode 100644 new mode 100755 diff --git a/examples/showcases/msssim_vae_reconstruction_mnist.png b/examples/showcases/msssim_vae_reconstruction_mnist.png old mode 100644 new mode 100755 diff --git a/examples/showcases/rae_gp_gmm_sampling_celeba.png b/examples/showcases/rae_gp_gmm_sampling_celeba.png old mode 100644 new mode 100755 diff --git a/examples/showcases/rae_gp_gmm_sampling_mnist.png b/examples/showcases/rae_gp_gmm_sampling_mnist.png old mode 100644 new mode 100755 diff --git a/examples/showcases/rae_gp_normal_sampling_celeba.png b/examples/showcases/rae_gp_normal_sampling_celeba.png old mode 100644 new mode 100755 diff --git a/examples/showcases/rae_gp_normal_sampling_mnist.png b/examples/showcases/rae_gp_normal_sampling_mnist.png old mode 100644 new mode 100755 diff --git a/examples/showcases/rae_gp_reconstruction_celeba.png b/examples/showcases/rae_gp_reconstruction_celeba.png old mode 100644 new mode 100755 diff --git a/examples/showcases/rae_gp_reconstruction_mnist.png b/examples/showcases/rae_gp_reconstruction_mnist.png old mode 100644 new mode 100755 diff --git a/examples/showcases/rae_l2_gmm_sampling_celeba.png b/examples/showcases/rae_l2_gmm_sampling_celeba.png old mode 100644 new mode 100755 diff --git a/examples/showcases/rae_l2_gmm_sampling_mnist.png b/examples/showcases/rae_l2_gmm_sampling_mnist.png old mode 100644 new mode 100755 diff --git a/examples/showcases/rae_l2_normal_sampling_celeba.png b/examples/showcases/rae_l2_normal_sampling_celeba.png old mode 100644 new mode 100755 diff --git a/examples/showcases/rae_l2_normal_sampling_mnist.png b/examples/showcases/rae_l2_normal_sampling_mnist.png old mode 100644 new mode 100755 diff --git a/examples/showcases/rae_l2_reconstruction_celeba.png b/examples/showcases/rae_l2_reconstruction_celeba.png old mode 100644 new mode 100755 diff --git a/examples/showcases/rae_l2_reconstruction_mnist.png b/examples/showcases/rae_l2_reconstruction_mnist.png old mode 100644 new mode 100755 diff --git a/examples/showcases/rhvae_reconstruction_celeba.png b/examples/showcases/rhvae_reconstruction_celeba.png old mode 100644 new mode 100755 diff --git a/examples/showcases/rhvae_reconstruction_mnist.png b/examples/showcases/rhvae_reconstruction_mnist.png old mode 100644 new mode 100755 diff --git a/examples/showcases/rhvae_rhvae_sampling_celeba.png b/examples/showcases/rhvae_rhvae_sampling_celeba.png old mode 100644 new mode 100755 diff --git a/examples/showcases/rhvae_rhvae_sampling_mnist.png b/examples/showcases/rhvae_rhvae_sampling_mnist.png old mode 100644 new mode 100755 diff --git a/examples/showcases/svae_hypersphere_uniform_sampling_celeba.png b/examples/showcases/svae_hypersphere_uniform_sampling_celeba.png old mode 100644 new mode 100755 diff --git a/examples/showcases/svae_hypersphere_uniform_sampling_mnist.png b/examples/showcases/svae_hypersphere_uniform_sampling_mnist.png old mode 100644 new mode 100755 diff --git a/examples/showcases/svae_reconstruction_celeba.png b/examples/showcases/svae_reconstruction_celeba.png old mode 100644 new mode 100755 diff --git a/examples/showcases/svae_reconstruction_mnist.png b/examples/showcases/svae_reconstruction_mnist.png old mode 100644 new mode 100755 diff --git a/examples/showcases/vae_gmm_sampling_celeba.png b/examples/showcases/vae_gmm_sampling_celeba.png old mode 100644 new mode 100755 diff --git a/examples/showcases/vae_gmm_sampling_mnist.png b/examples/showcases/vae_gmm_sampling_mnist.png old mode 100644 new mode 100755 diff --git a/examples/showcases/vae_iaf_normal_sampling_celeba.png b/examples/showcases/vae_iaf_normal_sampling_celeba.png old mode 100644 new mode 100755 diff --git a/examples/showcases/vae_iaf_normal_sampling_mnist.png b/examples/showcases/vae_iaf_normal_sampling_mnist.png old mode 100644 new mode 100755 diff --git a/examples/showcases/vae_iaf_reconstruction_celeba.png b/examples/showcases/vae_iaf_reconstruction_celeba.png old mode 100644 new mode 100755 diff --git a/examples/showcases/vae_iaf_reconstruction_mnist.png b/examples/showcases/vae_iaf_reconstruction_mnist.png old mode 100644 new mode 100755 diff --git a/examples/showcases/vae_lin_nf_normal_sampling_celeba.png b/examples/showcases/vae_lin_nf_normal_sampling_celeba.png old mode 100644 new mode 100755 diff --git a/examples/showcases/vae_lin_nf_normal_sampling_mnist.png b/examples/showcases/vae_lin_nf_normal_sampling_mnist.png old mode 100644 new mode 100755 diff --git a/examples/showcases/vae_lin_nf_reconstruction_celeba.png b/examples/showcases/vae_lin_nf_reconstruction_celeba.png old mode 100644 new mode 100755 diff --git a/examples/showcases/vae_lin_nf_reconstruction_mnist.png b/examples/showcases/vae_lin_nf_reconstruction_mnist.png old mode 100644 new mode 100755 diff --git a/examples/showcases/vae_maf_sampling_celeba.png b/examples/showcases/vae_maf_sampling_celeba.png old mode 100644 new mode 100755 diff --git a/examples/showcases/vae_maf_sampling_mnist.png b/examples/showcases/vae_maf_sampling_mnist.png old mode 100644 new mode 100755 diff --git a/examples/showcases/vae_normal_sampling_celeba.png b/examples/showcases/vae_normal_sampling_celeba.png old mode 100644 new mode 100755 diff --git a/examples/showcases/vae_normal_sampling_mnist.png b/examples/showcases/vae_normal_sampling_mnist.png old mode 100644 new mode 100755 diff --git a/examples/showcases/vae_reconstruction_celeba.png b/examples/showcases/vae_reconstruction_celeba.png old mode 100644 new mode 100755 diff --git a/examples/showcases/vae_reconstruction_mnist.png b/examples/showcases/vae_reconstruction_mnist.png old mode 100644 new mode 100755 diff --git a/examples/showcases/vae_second_stage_sampling_celeba.png b/examples/showcases/vae_second_stage_sampling_celeba.png old mode 100644 new mode 100755 diff --git a/examples/showcases/vae_second_stage_sampling_mnist.png b/examples/showcases/vae_second_stage_sampling_mnist.png old mode 100644 new mode 100755 diff --git a/examples/showcases/vaegan_normal_sampling_celeba.png b/examples/showcases/vaegan_normal_sampling_celeba.png old mode 100644 new mode 100755 diff --git a/examples/showcases/vaegan_normal_sampling_mnist.png b/examples/showcases/vaegan_normal_sampling_mnist.png old mode 100644 new mode 100755 diff --git a/examples/showcases/vaegan_reconstruction_celeba.png b/examples/showcases/vaegan_reconstruction_celeba.png old mode 100644 new mode 100755 diff --git a/examples/showcases/vaegan_reconstruction_mnist.png b/examples/showcases/vaegan_reconstruction_mnist.png old mode 100644 new mode 100755 diff --git a/examples/showcases/vamp_reconstruction_celeba.png b/examples/showcases/vamp_reconstruction_celeba.png old mode 100644 new mode 100755 diff --git a/examples/showcases/vamp_reconstruction_mnist.png b/examples/showcases/vamp_reconstruction_mnist.png old mode 100644 new mode 100755 diff --git a/examples/showcases/vamp_vamp_sampling_celeba.png b/examples/showcases/vamp_vamp_sampling_celeba.png old mode 100644 new mode 100755 diff --git a/examples/showcases/vamp_vamp_sampling_mnist.png b/examples/showcases/vamp_vamp_sampling_mnist.png old mode 100644 new mode 100755 diff --git a/examples/showcases/vqvae_maf_sampling_celeba.png b/examples/showcases/vqvae_maf_sampling_celeba.png old mode 100644 new mode 100755 diff --git a/examples/showcases/vqvae_maf_sampling_mnist.png b/examples/showcases/vqvae_maf_sampling_mnist.png old mode 100644 new mode 100755 diff --git a/examples/showcases/vqvae_pixelcnn_sampling_mnist.png b/examples/showcases/vqvae_pixelcnn_sampling_mnist.png old mode 100644 new mode 100755 diff --git a/examples/showcases/vqvae_reconstruction_celeba.png b/examples/showcases/vqvae_reconstruction_celeba.png old mode 100644 new mode 100755 diff --git a/examples/showcases/vqvae_reconstruction_mnist.png b/examples/showcases/vqvae_reconstruction_mnist.png old mode 100644 new mode 100755 diff --git a/examples/showcases/wae_gmm_sampling_celeba.png b/examples/showcases/wae_gmm_sampling_celeba.png old mode 100644 new mode 100755 diff --git a/examples/showcases/wae_gmm_sampling_mnist.png b/examples/showcases/wae_gmm_sampling_mnist.png old mode 100644 new mode 100755 diff --git a/examples/showcases/wae_normal_sampling_celeba.png b/examples/showcases/wae_normal_sampling_celeba.png old mode 100644 new mode 100755 diff --git a/examples/showcases/wae_normal_sampling_mnist.png b/examples/showcases/wae_normal_sampling_mnist.png old mode 100644 new mode 100755 diff --git a/examples/showcases/wae_reconstruction_celeba.png b/examples/showcases/wae_reconstruction_celeba.png old mode 100644 new mode 100755 diff --git a/examples/showcases/wae_reconstruction_mnist.png b/examples/showcases/wae_reconstruction_mnist.png old mode 100644 new mode 100755 diff --git a/pyproject.toml b/pyproject.toml old mode 100644 new mode 100755 diff --git a/requirements.txt b/requirements.txt old mode 100644 new mode 100755 diff --git a/setup.cfg b/setup.cfg old mode 100644 new mode 100755 diff --git a/setup.py b/setup.py old mode 100644 new mode 100755 diff --git a/src/pythae/__init__.py b/src/pythae/__init__.py old mode 100644 new mode 100755 diff --git a/src/pythae/config.py b/src/pythae/config.py old mode 100644 new mode 100755 diff --git a/src/pythae/customexception.py b/src/pythae/customexception.py old mode 100644 new mode 100755 diff --git a/src/pythae/data/__init__.py b/src/pythae/data/__init__.py old mode 100644 new mode 100755 diff --git a/src/pythae/data/datasets.py b/src/pythae/data/datasets.py old mode 100644 new mode 100755 diff --git a/src/pythae/data/preprocessors.py b/src/pythae/data/preprocessors.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/__init__.py b/src/pythae/models/__init__.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/adversarial_ae/__init__.py b/src/pythae/models/adversarial_ae/__init__.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/adversarial_ae/adversarial_ae_config.py b/src/pythae/models/adversarial_ae/adversarial_ae_config.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/adversarial_ae/adversarial_ae_model.py b/src/pythae/models/adversarial_ae/adversarial_ae_model.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/ae/__init__.py b/src/pythae/models/ae/__init__.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/ae/ae_config.py b/src/pythae/models/ae/ae_config.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/ae/ae_model.py b/src/pythae/models/ae/ae_model.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/auto_model/__init__.py b/src/pythae/models/auto_model/__init__.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/auto_model/auto_config.py b/src/pythae/models/auto_model/auto_config.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/auto_model/auto_model.py b/src/pythae/models/auto_model/auto_model.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/base/__init__.py b/src/pythae/models/base/__init__.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/base/base_config.py b/src/pythae/models/base/base_config.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/base/base_model.py b/src/pythae/models/base/base_model.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/base/base_utils.py b/src/pythae/models/base/base_utils.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/beta_tc_vae/__init__.py b/src/pythae/models/beta_tc_vae/__init__.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/beta_tc_vae/beta_tc_vae_config.py b/src/pythae/models/beta_tc_vae/beta_tc_vae_config.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/beta_tc_vae/beta_tc_vae_model.py b/src/pythae/models/beta_tc_vae/beta_tc_vae_model.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/beta_vae/__init__.py b/src/pythae/models/beta_vae/__init__.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/beta_vae/beta_vae_config.py b/src/pythae/models/beta_vae/beta_vae_config.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/beta_vae/beta_vae_model.py b/src/pythae/models/beta_vae/beta_vae_model.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/ciwae/__init__.py b/src/pythae/models/ciwae/__init__.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/ciwae/ciwae_config.py b/src/pythae/models/ciwae/ciwae_config.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/ciwae/ciwae_model.py b/src/pythae/models/ciwae/ciwae_model.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/disentangled_beta_vae/__init__.py b/src/pythae/models/disentangled_beta_vae/__init__.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/disentangled_beta_vae/disentangled_beta_vae_config.py b/src/pythae/models/disentangled_beta_vae/disentangled_beta_vae_config.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/disentangled_beta_vae/disentangled_beta_vae_model.py b/src/pythae/models/disentangled_beta_vae/disentangled_beta_vae_model.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/factor_vae/__init__.py b/src/pythae/models/factor_vae/__init__.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/factor_vae/factor_vae_config.py b/src/pythae/models/factor_vae/factor_vae_config.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/factor_vae/factor_vae_model.py b/src/pythae/models/factor_vae/factor_vae_model.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/factor_vae/factor_vae_utils.py b/src/pythae/models/factor_vae/factor_vae_utils.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/hvae/__init__.py b/src/pythae/models/hvae/__init__.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/hvae/hvae_config.py b/src/pythae/models/hvae/hvae_config.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/hvae/hvae_model.py b/src/pythae/models/hvae/hvae_model.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/info_vae/__init__.py b/src/pythae/models/info_vae/__init__.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/info_vae/info_vae_config.py b/src/pythae/models/info_vae/info_vae_config.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/info_vae/info_vae_model.py b/src/pythae/models/info_vae/info_vae_model.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/iwae/__init__.py b/src/pythae/models/iwae/__init__.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/iwae/iwae_config.py b/src/pythae/models/iwae/iwae_config.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/iwae/iwae_model.py b/src/pythae/models/iwae/iwae_model.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/miwae/__init__.py b/src/pythae/models/miwae/__init__.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/miwae/miwae_config.py b/src/pythae/models/miwae/miwae_config.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/miwae/miwae_model.py b/src/pythae/models/miwae/miwae_model.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/msssim_vae/__init__.py b/src/pythae/models/msssim_vae/__init__.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/msssim_vae/msssim_vae_config.py b/src/pythae/models/msssim_vae/msssim_vae_config.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/msssim_vae/msssim_vae_model.py b/src/pythae/models/msssim_vae/msssim_vae_model.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/msssim_vae/msssim_vae_utils.py b/src/pythae/models/msssim_vae/msssim_vae_utils.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/nn/__init__.py b/src/pythae/models/nn/__init__.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/nn/base_architectures.py b/src/pythae/models/nn/base_architectures.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/nn/benchmarks/__init__.py b/src/pythae/models/nn/benchmarks/__init__.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/nn/benchmarks/celeba/__init__.py b/src/pythae/models/nn/benchmarks/celeba/__init__.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/nn/benchmarks/celeba/convnets.py b/src/pythae/models/nn/benchmarks/celeba/convnets.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/nn/benchmarks/celeba/resnets.py b/src/pythae/models/nn/benchmarks/celeba/resnets.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/nn/benchmarks/cifar/__init__.py b/src/pythae/models/nn/benchmarks/cifar/__init__.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/nn/benchmarks/cifar/convnets.py b/src/pythae/models/nn/benchmarks/cifar/convnets.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/nn/benchmarks/cifar/resnets.py b/src/pythae/models/nn/benchmarks/cifar/resnets.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/nn/benchmarks/mnist/__init__.py b/src/pythae/models/nn/benchmarks/mnist/__init__.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/nn/benchmarks/mnist/convnets.py b/src/pythae/models/nn/benchmarks/mnist/convnets.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/nn/benchmarks/mnist/resnets.py b/src/pythae/models/nn/benchmarks/mnist/resnets.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/nn/benchmarks/utils.py b/src/pythae/models/nn/benchmarks/utils.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/nn/default_architectures.py b/src/pythae/models/nn/default_architectures.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/normalizing_flows/__init__.py b/src/pythae/models/normalizing_flows/__init__.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/normalizing_flows/base/__init__.py b/src/pythae/models/normalizing_flows/base/__init__.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/normalizing_flows/base/base_nf_config.py b/src/pythae/models/normalizing_flows/base/base_nf_config.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/normalizing_flows/base/base_nf_model.py b/src/pythae/models/normalizing_flows/base/base_nf_model.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/normalizing_flows/iaf/__init__.py b/src/pythae/models/normalizing_flows/iaf/__init__.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/normalizing_flows/iaf/iaf_config.py b/src/pythae/models/normalizing_flows/iaf/iaf_config.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/normalizing_flows/iaf/iaf_model.py b/src/pythae/models/normalizing_flows/iaf/iaf_model.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/normalizing_flows/layers.py b/src/pythae/models/normalizing_flows/layers.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/normalizing_flows/made/__init__.py b/src/pythae/models/normalizing_flows/made/__init__.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/normalizing_flows/made/made_config.py b/src/pythae/models/normalizing_flows/made/made_config.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/normalizing_flows/made/made_model.py b/src/pythae/models/normalizing_flows/made/made_model.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/normalizing_flows/maf/__init__.py b/src/pythae/models/normalizing_flows/maf/__init__.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/normalizing_flows/maf/maf_config.py b/src/pythae/models/normalizing_flows/maf/maf_config.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/normalizing_flows/maf/maf_model.py b/src/pythae/models/normalizing_flows/maf/maf_model.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/normalizing_flows/pixelcnn/__init__.py b/src/pythae/models/normalizing_flows/pixelcnn/__init__.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/normalizing_flows/pixelcnn/pixelcnn_config.py b/src/pythae/models/normalizing_flows/pixelcnn/pixelcnn_config.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/normalizing_flows/pixelcnn/pixelcnn_model.py b/src/pythae/models/normalizing_flows/pixelcnn/pixelcnn_model.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/normalizing_flows/pixelcnn/utils.py b/src/pythae/models/normalizing_flows/pixelcnn/utils.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/normalizing_flows/planar_flow/__init__.py b/src/pythae/models/normalizing_flows/planar_flow/__init__.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/normalizing_flows/planar_flow/planar_flow_config.py b/src/pythae/models/normalizing_flows/planar_flow/planar_flow_config.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/normalizing_flows/planar_flow/planar_flow_model.py b/src/pythae/models/normalizing_flows/planar_flow/planar_flow_model.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/normalizing_flows/radial_flow/__init__.py b/src/pythae/models/normalizing_flows/radial_flow/__init__.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/normalizing_flows/radial_flow/radial_flow_config.py b/src/pythae/models/normalizing_flows/radial_flow/radial_flow_config.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/normalizing_flows/radial_flow/radial_flow_model.py b/src/pythae/models/normalizing_flows/radial_flow/radial_flow_model.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/piwae/__init__.py b/src/pythae/models/piwae/__init__.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/piwae/piwae_config.py b/src/pythae/models/piwae/piwae_config.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/piwae/piwae_model.py b/src/pythae/models/piwae/piwae_model.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/pvae/__init__.py b/src/pythae/models/pvae/__init__.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/pvae/pvae_config.py b/src/pythae/models/pvae/pvae_config.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/pvae/pvae_model.py b/src/pythae/models/pvae/pvae_model.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/pvae/pvae_utils.py b/src/pythae/models/pvae/pvae_utils.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/rae_gp/__init__.py b/src/pythae/models/rae_gp/__init__.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/rae_gp/rae_gp_config.py b/src/pythae/models/rae_gp/rae_gp_config.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/rae_gp/rae_gp_model.py b/src/pythae/models/rae_gp/rae_gp_model.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/rae_l2/__init__.py b/src/pythae/models/rae_l2/__init__.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/rae_l2/rae_l2_config.py b/src/pythae/models/rae_l2/rae_l2_config.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/rae_l2/rae_l2_model.py b/src/pythae/models/rae_l2/rae_l2_model.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/rhvae/__init__.py b/src/pythae/models/rhvae/__init__.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/rhvae/rhvae_config.py b/src/pythae/models/rhvae/rhvae_config.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/rhvae/rhvae_model.py b/src/pythae/models/rhvae/rhvae_model.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/rhvae/rhvae_utils.py b/src/pythae/models/rhvae/rhvae_utils.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/svae/__init__.py b/src/pythae/models/svae/__init__.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/svae/svae_config.py b/src/pythae/models/svae/svae_config.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/svae/svae_model.py b/src/pythae/models/svae/svae_model.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/svae/svae_utils.py b/src/pythae/models/svae/svae_utils.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/vae/__init__.py b/src/pythae/models/vae/__init__.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/vae/vae_config.py b/src/pythae/models/vae/vae_config.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/vae/vae_model.py b/src/pythae/models/vae/vae_model.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/vae_gan/__init__.py b/src/pythae/models/vae_gan/__init__.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/vae_gan/vae_gan_config.py b/src/pythae/models/vae_gan/vae_gan_config.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/vae_gan/vae_gan_model.py b/src/pythae/models/vae_gan/vae_gan_model.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/vae_iaf/__init__.py b/src/pythae/models/vae_iaf/__init__.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/vae_iaf/vae_iaf_config.py b/src/pythae/models/vae_iaf/vae_iaf_config.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/vae_iaf/vae_iaf_model.py b/src/pythae/models/vae_iaf/vae_iaf_model.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/vae_lin_nf/__init__.py b/src/pythae/models/vae_lin_nf/__init__.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/vae_lin_nf/vae_lin_nf_config.py b/src/pythae/models/vae_lin_nf/vae_lin_nf_config.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/vae_lin_nf/vae_lin_nf_model.py b/src/pythae/models/vae_lin_nf/vae_lin_nf_model.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/vamp/__init__.py b/src/pythae/models/vamp/__init__.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/vamp/vamp_config.py b/src/pythae/models/vamp/vamp_config.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/vamp/vamp_model.py b/src/pythae/models/vamp/vamp_model.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/vq_vae/__init__.py b/src/pythae/models/vq_vae/__init__.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/vq_vae/vq_vae_config.py b/src/pythae/models/vq_vae/vq_vae_config.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/vq_vae/vq_vae_model.py b/src/pythae/models/vq_vae/vq_vae_model.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/vq_vae/vq_vae_utils.py b/src/pythae/models/vq_vae/vq_vae_utils.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/wae_mmd/__init__.py b/src/pythae/models/wae_mmd/__init__.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/wae_mmd/wae_mmd_config.py b/src/pythae/models/wae_mmd/wae_mmd_config.py old mode 100644 new mode 100755 diff --git a/src/pythae/models/wae_mmd/wae_mmd_model.py b/src/pythae/models/wae_mmd/wae_mmd_model.py old mode 100644 new mode 100755 diff --git a/src/pythae/pipelines/__init__.py b/src/pythae/pipelines/__init__.py old mode 100644 new mode 100755 diff --git a/src/pythae/pipelines/base_pipeline.py b/src/pythae/pipelines/base_pipeline.py old mode 100644 new mode 100755 diff --git a/src/pythae/pipelines/generation.py b/src/pythae/pipelines/generation.py old mode 100644 new mode 100755 diff --git a/src/pythae/pipelines/pipeline_utils.py b/src/pythae/pipelines/pipeline_utils.py old mode 100644 new mode 100755 diff --git a/src/pythae/pipelines/training.py b/src/pythae/pipelines/training.py old mode 100644 new mode 100755 diff --git a/src/pythae/py.typed b/src/pythae/py.typed old mode 100644 new mode 100755 diff --git a/src/pythae/samplers/__init__.py b/src/pythae/samplers/__init__.py old mode 100644 new mode 100755 diff --git a/src/pythae/samplers/base/__init__.py b/src/pythae/samplers/base/__init__.py old mode 100644 new mode 100755 diff --git a/src/pythae/samplers/base/base_sampler.py b/src/pythae/samplers/base/base_sampler.py old mode 100644 new mode 100755 diff --git a/src/pythae/samplers/base/base_sampler_config.py b/src/pythae/samplers/base/base_sampler_config.py old mode 100644 new mode 100755 diff --git a/src/pythae/samplers/gaussian_mixture/__init__.py b/src/pythae/samplers/gaussian_mixture/__init__.py old mode 100644 new mode 100755 diff --git a/src/pythae/samplers/gaussian_mixture/gaussian_mixture_config.py b/src/pythae/samplers/gaussian_mixture/gaussian_mixture_config.py old mode 100644 new mode 100755 diff --git a/src/pythae/samplers/gaussian_mixture/gaussian_mixture_sampler.py b/src/pythae/samplers/gaussian_mixture/gaussian_mixture_sampler.py old mode 100644 new mode 100755 diff --git a/src/pythae/samplers/hypersphere_uniform_sampler/__init__.py b/src/pythae/samplers/hypersphere_uniform_sampler/__init__.py old mode 100644 new mode 100755 diff --git a/src/pythae/samplers/hypersphere_uniform_sampler/hypersphere_uniform_config.py b/src/pythae/samplers/hypersphere_uniform_sampler/hypersphere_uniform_config.py old mode 100644 new mode 100755 diff --git a/src/pythae/samplers/hypersphere_uniform_sampler/hypersphere_uniform_sampler.py b/src/pythae/samplers/hypersphere_uniform_sampler/hypersphere_uniform_sampler.py old mode 100644 new mode 100755 diff --git a/src/pythae/samplers/iaf_sampler/__init__.py b/src/pythae/samplers/iaf_sampler/__init__.py old mode 100644 new mode 100755 diff --git a/src/pythae/samplers/iaf_sampler/iaf_sampler.py b/src/pythae/samplers/iaf_sampler/iaf_sampler.py old mode 100644 new mode 100755 diff --git a/src/pythae/samplers/iaf_sampler/iaf_sampler_config.py b/src/pythae/samplers/iaf_sampler/iaf_sampler_config.py old mode 100644 new mode 100755 diff --git a/src/pythae/samplers/maf_sampler/__init__.py b/src/pythae/samplers/maf_sampler/__init__.py old mode 100644 new mode 100755 diff --git a/src/pythae/samplers/maf_sampler/maf_sampler.py b/src/pythae/samplers/maf_sampler/maf_sampler.py old mode 100644 new mode 100755 diff --git a/src/pythae/samplers/maf_sampler/maf_sampler_config.py b/src/pythae/samplers/maf_sampler/maf_sampler_config.py old mode 100644 new mode 100755 diff --git a/src/pythae/samplers/manifold_sampler/__init__.py b/src/pythae/samplers/manifold_sampler/__init__.py old mode 100644 new mode 100755 diff --git a/src/pythae/samplers/manifold_sampler/rhvae_sampler.py b/src/pythae/samplers/manifold_sampler/rhvae_sampler.py old mode 100644 new mode 100755 diff --git a/src/pythae/samplers/manifold_sampler/rhvae_sampler_config.py b/src/pythae/samplers/manifold_sampler/rhvae_sampler_config.py old mode 100644 new mode 100755 diff --git a/src/pythae/samplers/normal_sampling/__init__.py b/src/pythae/samplers/normal_sampling/__init__.py old mode 100644 new mode 100755 diff --git a/src/pythae/samplers/normal_sampling/normal_config.py b/src/pythae/samplers/normal_sampling/normal_config.py old mode 100644 new mode 100755 diff --git a/src/pythae/samplers/normal_sampling/normal_sampler.py b/src/pythae/samplers/normal_sampling/normal_sampler.py old mode 100644 new mode 100755 diff --git a/src/pythae/samplers/pixelcnn_sampler/__init__.py b/src/pythae/samplers/pixelcnn_sampler/__init__.py old mode 100644 new mode 100755 diff --git a/src/pythae/samplers/pixelcnn_sampler/pixelcnn_sampler.py b/src/pythae/samplers/pixelcnn_sampler/pixelcnn_sampler.py old mode 100644 new mode 100755 diff --git a/src/pythae/samplers/pixelcnn_sampler/pixelcnn_sampler_config.py b/src/pythae/samplers/pixelcnn_sampler/pixelcnn_sampler_config.py old mode 100644 new mode 100755 diff --git a/src/pythae/samplers/pvae_sampler/__init__.py b/src/pythae/samplers/pvae_sampler/__init__.py old mode 100644 new mode 100755 diff --git a/src/pythae/samplers/pvae_sampler/pvae_sampler.py b/src/pythae/samplers/pvae_sampler/pvae_sampler.py old mode 100644 new mode 100755 diff --git a/src/pythae/samplers/pvae_sampler/pvae_sampler_config.py b/src/pythae/samplers/pvae_sampler/pvae_sampler_config.py old mode 100644 new mode 100755 diff --git a/src/pythae/samplers/two_stage_vae_sampler/__init__.py b/src/pythae/samplers/two_stage_vae_sampler/__init__.py old mode 100644 new mode 100755 diff --git a/src/pythae/samplers/two_stage_vae_sampler/two_stage_sampler.py b/src/pythae/samplers/two_stage_vae_sampler/two_stage_sampler.py old mode 100644 new mode 100755 diff --git a/src/pythae/samplers/two_stage_vae_sampler/two_stage_sampler_config.py b/src/pythae/samplers/two_stage_vae_sampler/two_stage_sampler_config.py old mode 100644 new mode 100755 diff --git a/src/pythae/samplers/vamp_sampler/__init__.py b/src/pythae/samplers/vamp_sampler/__init__.py old mode 100644 new mode 100755 diff --git a/src/pythae/samplers/vamp_sampler/vamp_sampler.py b/src/pythae/samplers/vamp_sampler/vamp_sampler.py old mode 100644 new mode 100755 diff --git a/src/pythae/samplers/vamp_sampler/vamp_sampler_config.py b/src/pythae/samplers/vamp_sampler/vamp_sampler_config.py old mode 100644 new mode 100755 diff --git a/src/pythae/trainers/__init__.py b/src/pythae/trainers/__init__.py old mode 100644 new mode 100755 diff --git a/src/pythae/trainers/adversarial_trainer/__init__.py b/src/pythae/trainers/adversarial_trainer/__init__.py old mode 100644 new mode 100755 diff --git a/src/pythae/trainers/adversarial_trainer/adversarial_trainer.py b/src/pythae/trainers/adversarial_trainer/adversarial_trainer.py old mode 100644 new mode 100755 diff --git a/src/pythae/trainers/adversarial_trainer/adversarial_trainer_config.py b/src/pythae/trainers/adversarial_trainer/adversarial_trainer_config.py old mode 100644 new mode 100755 diff --git a/src/pythae/trainers/base_trainer/__init__.py b/src/pythae/trainers/base_trainer/__init__.py old mode 100644 new mode 100755 diff --git a/src/pythae/trainers/base_trainer/base_trainer.py b/src/pythae/trainers/base_trainer/base_trainer.py old mode 100644 new mode 100755 diff --git a/src/pythae/trainers/base_trainer/base_training_config.py b/src/pythae/trainers/base_trainer/base_training_config.py old mode 100644 new mode 100755 diff --git a/src/pythae/trainers/coupled_optimizer_adversarial_trainer/__init__.py b/src/pythae/trainers/coupled_optimizer_adversarial_trainer/__init__.py old mode 100644 new mode 100755 diff --git a/src/pythae/trainers/coupled_optimizer_adversarial_trainer/coupled_optimizer_adversarial_trainer.py b/src/pythae/trainers/coupled_optimizer_adversarial_trainer/coupled_optimizer_adversarial_trainer.py old mode 100644 new mode 100755 diff --git a/src/pythae/trainers/coupled_optimizer_adversarial_trainer/coupled_optimizer_adversarial_trainer_config.py b/src/pythae/trainers/coupled_optimizer_adversarial_trainer/coupled_optimizer_adversarial_trainer_config.py old mode 100644 new mode 100755 diff --git a/src/pythae/trainers/coupled_optimizer_trainer/__init__.py b/src/pythae/trainers/coupled_optimizer_trainer/__init__.py old mode 100644 new mode 100755 diff --git a/src/pythae/trainers/coupled_optimizer_trainer/coupled_optimizer_trainer.py b/src/pythae/trainers/coupled_optimizer_trainer/coupled_optimizer_trainer.py old mode 100644 new mode 100755 diff --git a/src/pythae/trainers/coupled_optimizer_trainer/coupled_optimizer_trainer_config.py b/src/pythae/trainers/coupled_optimizer_trainer/coupled_optimizer_trainer_config.py old mode 100644 new mode 100755 diff --git a/src/pythae/trainers/trainer_utils.py b/src/pythae/trainers/trainer_utils.py old mode 100644 new mode 100755 diff --git a/src/pythae/trainers/training_callbacks.py b/src/pythae/trainers/training_callbacks.py old mode 100644 new mode 100755 diff --git a/tests/README.md b/tests/README.md old mode 100644 new mode 100755 diff --git a/tests/__init__.py b/tests/__init__.py old mode 100644 new mode 100755 diff --git a/tests/conftest.py b/tests/conftest.py old mode 100644 new mode 100755 diff --git a/tests/data/baseAE/configs/corrupted_model_config.json b/tests/data/baseAE/configs/corrupted_model_config.json old mode 100644 new mode 100755 diff --git a/tests/data/baseAE/configs/generation_config00.json b/tests/data/baseAE/configs/generation_config00.json old mode 100644 new mode 100755 diff --git a/tests/data/baseAE/configs/model_config00.json b/tests/data/baseAE/configs/model_config00.json old mode 100644 new mode 100755 diff --git a/tests/data/baseAE/configs/not_json_file.md b/tests/data/baseAE/configs/not_json_file.md old mode 100644 new mode 100755 diff --git a/tests/data/baseAE/configs/training_config00.json b/tests/data/baseAE/configs/training_config00.json old mode 100644 new mode 100755 diff --git a/tests/data/corrupted_config/model_config.json b/tests/data/corrupted_config/model_config.json old mode 100644 new mode 100755 diff --git a/tests/data/custom_architectures.py b/tests/data/custom_architectures.py old mode 100644 new mode 100755 diff --git a/tests/data/loading/dummy_data_folder/example0.bmp b/tests/data/loading/dummy_data_folder/example0.bmp old mode 100644 new mode 100755 diff --git a/tests/data/loading/dummy_data_folder/example0.jpeg b/tests/data/loading/dummy_data_folder/example0.jpeg old mode 100644 new mode 100755 diff --git a/tests/data/loading/dummy_data_folder/example0.jpg b/tests/data/loading/dummy_data_folder/example0.jpg old mode 100644 new mode 100755 diff --git a/tests/data/loading/dummy_data_folder/example0.png b/tests/data/loading/dummy_data_folder/example0.png old mode 100644 new mode 100755 diff --git a/tests/data/loading/dummy_data_folder/example0_downsampled_12_12.jpg b/tests/data/loading/dummy_data_folder/example0_downsampled_12_12.jpg old mode 100644 new mode 100755 diff --git a/tests/data/mnist_clean_train_dataset_sample b/tests/data/mnist_clean_train_dataset_sample old mode 100644 new mode 100755 diff --git a/tests/data/rhvae/configs/model_config00.json b/tests/data/rhvae/configs/model_config00.json old mode 100644 new mode 100755 diff --git a/tests/data/rhvae/configs/trained_model_folder/model.pt b/tests/data/rhvae/configs/trained_model_folder/model.pt old mode 100644 new mode 100755 diff --git a/tests/data/rhvae/configs/trained_model_folder/model_config.json b/tests/data/rhvae/configs/trained_model_folder/model_config.json old mode 100644 new mode 100755 diff --git a/tests/data/unnormalized_mnist_data_array b/tests/data/unnormalized_mnist_data_array old mode 100644 new mode 100755 diff --git a/tests/data/unnormalized_mnist_data_list_of_array b/tests/data/unnormalized_mnist_data_list_of_array old mode 100644 new mode 100755 diff --git a/tests/pytest.ini b/tests/pytest.ini old mode 100644 new mode 100755 diff --git a/tests/test_AE.py b/tests/test_AE.py old mode 100644 new mode 100755 diff --git a/tests/test_Adversarial_AE.py b/tests/test_Adversarial_AE.py old mode 100644 new mode 100755 diff --git a/tests/test_BetaTCVAE.py b/tests/test_BetaTCVAE.py old mode 100644 new mode 100755 diff --git a/tests/test_BetaVAE.py b/tests/test_BetaVAE.py old mode 100644 new mode 100755 diff --git a/tests/test_CIWAE.py b/tests/test_CIWAE.py old mode 100644 new mode 100755 diff --git a/tests/test_DisentangledBetaVAE.py b/tests/test_DisentangledBetaVAE.py old mode 100644 new mode 100755 diff --git a/tests/test_FactorVAE.py b/tests/test_FactorVAE.py old mode 100644 new mode 100755 diff --git a/tests/test_HVAE.py b/tests/test_HVAE.py old mode 100644 new mode 100755 diff --git a/tests/test_IAF.py b/tests/test_IAF.py old mode 100644 new mode 100755 diff --git a/tests/test_IWAE.py b/tests/test_IWAE.py old mode 100644 new mode 100755 diff --git a/tests/test_MADE.py b/tests/test_MADE.py old mode 100644 new mode 100755 diff --git a/tests/test_MAF.py b/tests/test_MAF.py old mode 100644 new mode 100755 diff --git a/tests/test_MIWAE.py b/tests/test_MIWAE.py old mode 100644 new mode 100755 diff --git a/tests/test_MSSSIMVAE.py b/tests/test_MSSSIMVAE.py old mode 100644 new mode 100755 diff --git a/tests/test_PIWAE.py b/tests/test_PIWAE.py old mode 100644 new mode 100755 diff --git a/tests/test_PixelCNN.py b/tests/test_PixelCNN.py old mode 100644 new mode 100755 diff --git a/tests/test_PoincareVAE.py b/tests/test_PoincareVAE.py old mode 100644 new mode 100755 diff --git a/tests/test_RHVAE.py b/tests/test_RHVAE.py old mode 100644 new mode 100755 diff --git a/tests/test_SVAE.py b/tests/test_SVAE.py old mode 100644 new mode 100755 diff --git a/tests/test_VAE.py b/tests/test_VAE.py old mode 100644 new mode 100755 diff --git a/tests/test_VAEGAN.py b/tests/test_VAEGAN.py old mode 100644 new mode 100755 diff --git a/tests/test_VAE_IAF.py b/tests/test_VAE_IAF.py old mode 100644 new mode 100755 diff --git a/tests/test_VAE_LinFlow.py b/tests/test_VAE_LinFlow.py old mode 100644 new mode 100755 diff --git a/tests/test_VAMP.py b/tests/test_VAMP.py old mode 100644 new mode 100755 diff --git a/tests/test_VQVAE.py b/tests/test_VQVAE.py old mode 100644 new mode 100755 diff --git a/tests/test_WAE_MMD.py b/tests/test_WAE_MMD.py old mode 100644 new mode 100755 diff --git a/tests/test_adversarial_trainer.py b/tests/test_adversarial_trainer.py old mode 100644 new mode 100755 diff --git a/tests/test_auto_model.py b/tests/test_auto_model.py old mode 100644 new mode 100755 diff --git a/tests/test_baseAE.py b/tests/test_baseAE.py old mode 100644 new mode 100755 diff --git a/tests/test_baseSampler.py b/tests/test_baseSampler.py old mode 100644 new mode 100755 diff --git a/tests/test_base_trainer.py b/tests/test_base_trainer.py old mode 100644 new mode 100755 diff --git a/tests/test_config.py b/tests/test_config.py old mode 100644 new mode 100755 diff --git a/tests/test_coupled_optimizers_adversarial_trainer.py b/tests/test_coupled_optimizers_adversarial_trainer.py old mode 100644 new mode 100755 diff --git a/tests/test_coupled_optimizers_trainer.py b/tests/test_coupled_optimizers_trainer.py old mode 100644 new mode 100755 diff --git a/tests/test_datasets.py b/tests/test_datasets.py old mode 100644 new mode 100755 diff --git a/tests/test_gaussian_mixture_sampler.py b/tests/test_gaussian_mixture_sampler.py old mode 100644 new mode 100755 diff --git a/tests/test_hypersphere_uniform_sampler.py b/tests/test_hypersphere_uniform_sampler.py old mode 100644 new mode 100755 diff --git a/tests/test_iaf_sampler.py b/tests/test_iaf_sampler.py old mode 100644 new mode 100755 diff --git a/tests/test_info_vae_mmd.py b/tests/test_info_vae_mmd.py old mode 100644 new mode 100755 diff --git a/tests/test_maf_sampler.py b/tests/test_maf_sampler.py old mode 100644 new mode 100755 diff --git a/tests/test_nn_benchmark.py b/tests/test_nn_benchmark.py old mode 100644 new mode 100755 diff --git a/tests/test_normal_sampler.py b/tests/test_normal_sampler.py old mode 100644 new mode 100755 diff --git a/tests/test_pipeline_standalone.py b/tests/test_pipeline_standalone.py old mode 100644 new mode 100755 diff --git a/tests/test_pixelcnn_sampler.py b/tests/test_pixelcnn_sampler.py old mode 100644 new mode 100755 diff --git a/tests/test_planar_flow.py b/tests/test_planar_flow.py old mode 100644 new mode 100755 diff --git a/tests/test_preprocessing.py b/tests/test_preprocessing.py old mode 100644 new mode 100755 diff --git a/tests/test_pvae_sampler.py b/tests/test_pvae_sampler.py old mode 100644 new mode 100755 diff --git a/tests/test_radial_flow.py b/tests/test_radial_flow.py old mode 100644 new mode 100755 diff --git a/tests/test_rae_gp.py b/tests/test_rae_gp.py old mode 100644 new mode 100755 diff --git a/tests/test_rae_l2.py b/tests/test_rae_l2.py old mode 100644 new mode 100755 diff --git a/tests/test_rhvae_sampler.py b/tests/test_rhvae_sampler.py old mode 100644 new mode 100755 diff --git a/tests/test_training_callbacks.py b/tests/test_training_callbacks.py old mode 100644 new mode 100755 diff --git a/tests/test_two_stage_sampler.py b/tests/test_two_stage_sampler.py old mode 100644 new mode 100755 diff --git a/tests/test_vamp_sampler.py b/tests/test_vamp_sampler.py old mode 100644 new mode 100755 diff --git a/tests/your_file.jpeg b/tests/your_file.jpeg old mode 100644 new mode 100755 From ee4a8c2a218dd71fb40ed85370fb9ee9d40fff0a Mon Sep 17 00:00:00 2001 From: "soumick.chatterjee" Date: Sat, 27 May 2023 11:00:44 +0200 Subject: [PATCH 09/15] cvcnn added --- .../models/factor_vae/factor_vae_model.py | 13 ++++++++--- .../models/factor_vae/factor_vae_utils.py | 23 +++++++++++++++++++ 2 files changed, 33 insertions(+), 3 deletions(-) diff --git a/src/pythae/models/factor_vae/factor_vae_model.py b/src/pythae/models/factor_vae/factor_vae_model.py index c9aa8e01..5f7ea755 100755 --- a/src/pythae/models/factor_vae/factor_vae_model.py +++ b/src/pythae/models/factor_vae/factor_vae_model.py @@ -10,7 +10,7 @@ from ..nn import BaseDecoder, BaseDiscriminator, BaseEncoder from ..vae import VAE from .factor_vae_config import FactorVAEConfig -from .factor_vae_utils import FactorVAEDiscriminator +from .factor_vae_utils import FactorVAEDiscriminator, CVFactorVAEDiscriminator class FactorVAE(VAE): @@ -51,8 +51,11 @@ def __init__( ): VAE.__init__(self, model_config=model_config, encoder=encoder, decoder=decoder) - - self.discriminator = FactorVAEDiscriminator(latent_dim=model_config.latent_dim) + + if model_config.is_cv: + self.discriminator = CVFactorVAEDiscriminator(latent_dim=model_config.latent_dim) + else: + self.discriminator = FactorVAEDiscriminator(latent_dim=model_config.latent_dim) self.model_name = "FactorVAE" self.gamma = model_config.gamma @@ -208,6 +211,10 @@ def loss_function(self, recon_x, x, mu, log_var, z, z_bis_permuted, mask=None): .to(z.device) ) + if torch.is_complex(latent_adversarial_score): + latent_adversarial_score = torch.abs(latent_adversarial_score) + permuted_latent_adversarial_score = torch.abs(permuted_latent_adversarial_score) + TC_permuted = F.cross_entropy( latent_adversarial_score, fake_labels ) + F.cross_entropy(permuted_latent_adversarial_score, true_labels) diff --git a/src/pythae/models/factor_vae/factor_vae_utils.py b/src/pythae/models/factor_vae/factor_vae_utils.py index 62fd68ea..356cefd3 100755 --- a/src/pythae/models/factor_vae/factor_vae_utils.py +++ b/src/pythae/models/factor_vae/factor_vae_utils.py @@ -1,5 +1,6 @@ import torch import torch.nn as nn +import torchcomplex.nn as cnn class FactorVAEDiscriminator(nn.Module): @@ -23,3 +24,25 @@ def __init__(self, latent_dim=16, hidden_units=1000) -> None: def forward(self, z: torch.Tensor): return self.layers(z) + +class CVFactorVAEDiscriminator(nn.Module): + def __init__(self, latent_dim=16, hidden_units=1000) -> None: + + nn.Module.__init__(self) + + self.layers = nn.Sequential( + cnn.Linear(latent_dim, hidden_units), + cnn.CReLU(), + cnn.Linear(hidden_units, hidden_units), + cnn.CReLU(), + cnn.Linear(hidden_units, hidden_units), + cnn.CReLU(), + cnn.Linear(hidden_units, hidden_units), + cnn.CReLU(), + cnn.Linear(hidden_units, hidden_units), + cnn.CReLU(), + cnn.Linear(hidden_units, 2), + ) + + def forward(self, z: torch.Tensor): + return self.layers(z) From 151f7f382e1d2ee4dc57cd42d103383af0ea0e22 Mon Sep 17 00:00:00 2001 From: "soumick.chatterjee" Date: Sat, 27 May 2023 12:02:58 +0200 Subject: [PATCH 10/15] vae config l1 recon loss added --- src/pythae/models/vae/vae_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pythae/models/vae/vae_config.py b/src/pythae/models/vae/vae_config.py index ee904d48..231e16cc 100755 --- a/src/pythae/models/vae/vae_config.py +++ b/src/pythae/models/vae/vae_config.py @@ -14,4 +14,4 @@ class VAEConfig(BaseAEConfig): reconstruction_loss (str): The reconstruction loss to use ['bce', 'l1', 'mse']. Default: 'mse' """ - reconstruction_loss: Literal["bce", "mse"] = "mse" + reconstruction_loss: Literal["bce", "mse", "l1"] = "mse" From 101dafb4ee680311f44f5edfb858fe13b0911b38 Mon Sep 17 00:00:00 2001 From: "soumick.chatterjee" Date: Tue, 13 Jun 2023 13:57:55 +0200 Subject: [PATCH 11/15] reverting factor VAE back to original --- src/pythae/models/factor_vae/__init__.py | 0 .../models/factor_vae/factor_vae_config.py | 0 .../models/factor_vae/factor_vae_model.py | 53 ++++--------------- .../models/factor_vae/factor_vae_utils.py | 23 -------- 4 files changed, 9 insertions(+), 67 deletions(-) mode change 100755 => 100644 src/pythae/models/factor_vae/__init__.py mode change 100755 => 100644 src/pythae/models/factor_vae/factor_vae_config.py mode change 100755 => 100644 src/pythae/models/factor_vae/factor_vae_model.py mode change 100755 => 100644 src/pythae/models/factor_vae/factor_vae_utils.py diff --git a/src/pythae/models/factor_vae/__init__.py b/src/pythae/models/factor_vae/__init__.py old mode 100755 new mode 100644 diff --git a/src/pythae/models/factor_vae/factor_vae_config.py b/src/pythae/models/factor_vae/factor_vae_config.py old mode 100755 new mode 100644 diff --git a/src/pythae/models/factor_vae/factor_vae_model.py b/src/pythae/models/factor_vae/factor_vae_model.py old mode 100755 new mode 100644 index 5f7ea755..97c691eb --- a/src/pythae/models/factor_vae/factor_vae_model.py +++ b/src/pythae/models/factor_vae/factor_vae_model.py @@ -1,5 +1,5 @@ import os -from typing import Optional, Union, Callable +from typing import Optional import torch import torch.nn.functional as F @@ -10,7 +10,7 @@ from ..nn import BaseDecoder, BaseDiscriminator, BaseEncoder from ..vae import VAE from .factor_vae_config import FactorVAEConfig -from .factor_vae_utils import FactorVAEDiscriminator, CVFactorVAEDiscriminator +from .factor_vae_utils import FactorVAEDiscriminator class FactorVAE(VAE): @@ -31,11 +31,6 @@ class FactorVAE(VAE): architectures if desired. If None is provided, a simple Multi Layer Preception (https://en.wikipedia.org/wiki/Multilayer_perceptron) is used. Default: None. - custom_recon_loss_func: (torch.nn.Module or Callable): A custom loss function for calculation the reconstruction loss. - This is only used when the `reconstruction_loss` parameter in `model_config` is set to `custom`. This can be either - an instance of `torch.nn.Module` or a callable function. In either case, the function must take the following arguments: - - `recon_x`: The reconstructed data - - `x`: The original data. Default: None. .. note:: For high dimensional data we advice you to provide you own network architectures. With the @@ -47,19 +42,14 @@ def __init__( model_config: FactorVAEConfig, encoder: Optional[BaseEncoder] = None, decoder: Optional[BaseDecoder] = None, - custom_recon_loss_func: Optional[Union[torch.nn.Module, Callable]] = None, ): VAE.__init__(self, model_config=model_config, encoder=encoder, decoder=decoder) - - if model_config.is_cv: - self.discriminator = CVFactorVAEDiscriminator(latent_dim=model_config.latent_dim) - else: - self.discriminator = FactorVAEDiscriminator(latent_dim=model_config.latent_dim) + + self.discriminator = FactorVAEDiscriminator(latent_dim=model_config.latent_dim) self.model_name = "FactorVAE" self.gamma = model_config.gamma - self.custom_recon_loss_func = custom_recon_loss_func def set_discriminator(self, discriminator: BaseDiscriminator) -> None: r"""This method is called to set the discriminator network @@ -102,13 +92,6 @@ def forward(self, inputs: BaseDataset, **kwargs) -> ModelOutput: # first batch x = inputs["data"][idx_1] - if self.model_config.reconstruction_loss == "custom_masked": - if "mask" not in inputs.keys(): - raise ValueError( - "No mask not present in the input for `custom_masked` reconstruction loss" - ) - mask = inputs["mask"][idx_1] - encoder_output = self.encoder(x) mu, log_var = encoder_output.embedding, encoder_output.log_covariance @@ -129,15 +112,9 @@ def forward(self, inputs: BaseDataset, **kwargs) -> ModelOutput: z_bis_permuted = self._permute_dims(z_bis).detach() - if not self.model_config.reconstruction_loss == "custom_masked": - recon_loss, autoencoder_loss, discriminator_loss = self.loss_function( - recon_x, x, mu, log_var, z, z_bis_permuted - ) - else: - recon_loss, autoencoder_loss, discriminator_loss = self.loss_function( - recon_x, x, mu, log_var, z, z_bis_permuted, mask - ) - + recon_loss, autoencoder_loss, discriminator_loss = self.loss_function( + recon_x, x, mu, log_var, z, z_bis_permuted + ) loss = autoencoder_loss + discriminator_loss @@ -154,7 +131,7 @@ def forward(self, inputs: BaseDataset, **kwargs) -> ModelOutput: return output - def loss_function(self, recon_x, x, mu, log_var, z, z_bis_permuted, mask=None): + def loss_function(self, recon_x, x, mu, log_var, z, z_bis_permuted): N = z.shape[0] # batch size @@ -182,14 +159,6 @@ def loss_function(self, recon_x, x, mu, log_var, z, z_bis_permuted, mask=None): reduction="none", ).sum(dim=-1) - elif self.model_config.reconstruction_loss == "custom": - - recon_loss = self.custom_recon_loss_func(recon_x, x) - - elif self.model_config.reconstruction_loss == "custom_masked": - - recon_loss = self.custom_recon_loss_func(recon_x, x, mask) - KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp(), dim=-1) latent_adversarial_score = self.discriminator(z) @@ -211,10 +180,6 @@ def loss_function(self, recon_x, x, mu, log_var, z, z_bis_permuted, mask=None): .to(z.device) ) - if torch.is_complex(latent_adversarial_score): - latent_adversarial_score = torch.abs(latent_adversarial_score) - permuted_latent_adversarial_score = torch.abs(permuted_latent_adversarial_score) - TC_permuted = F.cross_entropy( latent_adversarial_score, fake_labels ) + F.cross_entropy(permuted_latent_adversarial_score, true_labels) @@ -313,4 +278,4 @@ def _permute_dims(self, z): perms = torch.randperm(z.shape[0]).to(z.device) permuted[:, i] = z[perms, i] - return permuted \ No newline at end of file + return permuted diff --git a/src/pythae/models/factor_vae/factor_vae_utils.py b/src/pythae/models/factor_vae/factor_vae_utils.py old mode 100755 new mode 100644 index 356cefd3..62fd68ea --- a/src/pythae/models/factor_vae/factor_vae_utils.py +++ b/src/pythae/models/factor_vae/factor_vae_utils.py @@ -1,6 +1,5 @@ import torch import torch.nn as nn -import torchcomplex.nn as cnn class FactorVAEDiscriminator(nn.Module): @@ -24,25 +23,3 @@ def __init__(self, latent_dim=16, hidden_units=1000) -> None: def forward(self, z: torch.Tensor): return self.layers(z) - -class CVFactorVAEDiscriminator(nn.Module): - def __init__(self, latent_dim=16, hidden_units=1000) -> None: - - nn.Module.__init__(self) - - self.layers = nn.Sequential( - cnn.Linear(latent_dim, hidden_units), - cnn.CReLU(), - cnn.Linear(hidden_units, hidden_units), - cnn.CReLU(), - cnn.Linear(hidden_units, hidden_units), - cnn.CReLU(), - cnn.Linear(hidden_units, hidden_units), - cnn.CReLU(), - cnn.Linear(hidden_units, hidden_units), - cnn.CReLU(), - cnn.Linear(hidden_units, 2), - ) - - def forward(self, z: torch.Tensor): - return self.layers(z) From cc0078ee88f4d30ed5609bfd718a4bc801cea41c Mon Sep 17 00:00:00 2001 From: "Soumick Chatterjee, PhD" Date: Sat, 8 Jul 2023 01:08:44 +0200 Subject: [PATCH 12/15] ae_model updated with recon loss --- src/pythae/models/ae/ae_model.py | 31 ++++++++++++++++++++++++++----- 1 file changed, 26 insertions(+), 5 deletions(-) diff --git a/src/pythae/models/ae/ae_model.py b/src/pythae/models/ae/ae_model.py index 6be833c6..addcae31 100755 --- a/src/pythae/models/ae/ae_model.py +++ b/src/pythae/models/ae/ae_model.py @@ -81,10 +81,31 @@ def forward(self, inputs: BaseDataset, **kwargs) -> ModelOutput: output = ModelOutput(loss=loss, recon_x=recon_x, z=z) return output - + def loss_function(self, recon_x, x): - MSE = F.mse_loss( - recon_x.reshape(x.shape[0], -1), x.reshape(x.shape[0], -1), reduction="none" - ).sum(dim=-1) - return MSE.mean(dim=0) + if self.model_config.reconstruction_loss == "mse": + + recon_loss = F.mse_loss( + recon_x.reshape(x.shape[0], -1), + x.reshape(x.shape[0], -1), + reduction="none", + ).sum(dim=-1) + + elif self.model_config.reconstruction_loss == "bce": + + recon_loss = F.binary_cross_entropy( + recon_x.reshape(x.shape[0], -1), + x.reshape(x.shape[0], -1), + reduction="none", + ).sum(dim=-1) + + elif self.model_config.reconstruction_loss == "l1": + + recon_loss = F.l1_loss( + recon_x.reshape(x.shape[0], -1), + x.reshape(x.shape[0], -1), + reduction="none", + ).sum(dim=-1) + + return recon_loss.mean(dim=0) From d1739e3d8d4e09132753abffcb247283136374d2 Mon Sep 17 00:00:00 2001 From: "Soumick Chatterjee, PhD" Date: Sat, 8 Jul 2023 01:09:16 +0200 Subject: [PATCH 13/15] ae_config updated with recon loss param --- src/pythae/models/ae/ae_config.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/pythae/models/ae/ae_config.py b/src/pythae/models/ae/ae_config.py index b7145d2e..2593184a 100755 --- a/src/pythae/models/ae/ae_config.py +++ b/src/pythae/models/ae/ae_config.py @@ -13,4 +13,7 @@ class AEConfig(BaseAEConfig): latent_dim (int): The latent space dimension. Default: None. default_encoder (bool): Whether the encoder default. Default: True. default_decoder (bool): Whether the encoder default. Default: True. + reconstruction_loss (str): The reconstruction loss to use ['bce', 'l1', 'mse']. Default: 'mse' """ + + reconstruction_loss: Literal["bce", "mse", "l1"] = "mse" From e7c96a829ebc2a9c656d7af2cbb260195501f818 Mon Sep 17 00:00:00 2001 From: "Soumick Chatterjee, PhD" Date: Sat, 8 Jul 2023 16:35:45 +0200 Subject: [PATCH 14/15] minor bug fix in ae_config.py --- src/pythae/models/ae/ae_config.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/pythae/models/ae/ae_config.py b/src/pythae/models/ae/ae_config.py index 2593184a..681f33a0 100755 --- a/src/pythae/models/ae/ae_config.py +++ b/src/pythae/models/ae/ae_config.py @@ -1,4 +1,5 @@ from pydantic.dataclasses import dataclass +from typing_extensions import Literal from ..base.base_config import BaseAEConfig From c85c8f60e1dfebe472b45ee153a665b1bb242da8 Mon Sep 17 00:00:00 2001 From: "soumick.chatterjee" Date: Tue, 18 Jul 2023 15:17:04 +0200 Subject: [PATCH 15/15] predict added in rhvae --- src/pythae/models/rhvae/rhvae_model.py | 92 ++++++++++++++++++-------- 1 file changed, 66 insertions(+), 26 deletions(-) diff --git a/src/pythae/models/rhvae/rhvae_model.py b/src/pythae/models/rhvae/rhvae_model.py index 8ce22f7a..b37eabfc 100755 --- a/src/pythae/models/rhvae/rhvae_model.py +++ b/src/pythae/models/rhvae/rhvae_model.py @@ -263,6 +263,72 @@ def forward(self, inputs: BaseDataset, **kwargs) -> ModelOutput: return output + def predict(self, inputs: torch.Tensor) -> ModelOutput: + """The input data is encoded and decoded without computing loss + + Args: + inputs (torch.Tensor): The input data to be reconstructed, as well as to generate the embedding. + + Returns: + ModelOutput: An instance of ModelOutput containing reconstruction, raw embedding (output of encoder), and the final embedding (output of metric) + """ + encoder_output = self.encoder(inputs) + mu, log_var = encoder_output.embedding, encoder_output.log_covariance + + std = torch.exp(0.5 * log_var) + z0, _ = self._sample_gauss(mu, std) + + z = z0 + + G = self.G(z) + G_inv = self.G_inv(z) + L = torch.linalg.cholesky(G) + + G_log_det = -torch.logdet(G_inv) + + gamma = torch.randn_like(z0, device=inputs.device) + rho = gamma / self.beta_zero_sqrt + beta_sqrt_old = self.beta_zero_sqrt + + # sample \rho from N(0, G) + rho = (L @ rho.unsqueeze(-1)).squeeze(-1) + + recon_x = self.decoder(z)["reconstruction"] + + for k in range(self.n_lf): + + # perform leapfrog steps + + # step 1 + rho_ = self._leap_step_1(recon_x, inputs, z, rho, G_inv, G_log_det) + + # step 2 + z = self._leap_step_2(recon_x, inputs, z, rho_, G_inv, G_log_det) + + recon_x = self.decoder(z)["reconstruction"] + + # compute metric value on new z using final metric + G = self.G(z) + G_inv = self.G_inv(z) + + G_log_det = -torch.logdet(G_inv) + + # step 3 + rho__ = self._leap_step_3(recon_x, inputs, z, rho_, G_inv, G_log_det) + + # tempering + beta_sqrt = self._tempering(k + 1, self.n_lf) + rho = (beta_sqrt_old / beta_sqrt) * rho__ + beta_sqrt_old = beta_sqrt + + output = ModelOutput( + recon_x=recon_x, + raw_embedding=encoder_output.embedding, + embedding=z if self.n_lf > 0 else encoder_output.embedding, + ) + + return output + def _leap_step_1(self, recon_x, x, z, rho, G_inv, G_log_det, steps=3): """ Resolves first equation of generalized leapfrog integrator @@ -424,20 +490,6 @@ def _log_p_x_given_z(self, recon_x, x): reduction="none", ).sum(dim=-1) - elif self.model_config.reconstruction_loss == "l1": - # sigma is taken as I_D - recon_loss = ( - -0.5 - * F.l1_loss( - recon_x.reshape(x.shape[0], -1), - x.reshape(x.shape[0], -1), - reduction="none", - ).sum(dim=-1) - ) - -torch.log(torch.tensor([2 * np.pi]).to(x.device)) * np.prod( - self.input_dim - ) / 2 - return recon_loss def _log_z(self, z): @@ -586,18 +638,6 @@ def get_nll(self, data, n_samples=1, batch_size=100): reduction="none", ).sum(dim=-1) - elif self.model_config.reconstruction_loss == "l1": - - log_p_x_given_z = -0.5 * F.l1_loss( - recon_x.reshape(x_rep.shape[0], -1), - x_rep.reshape(x_rep.shape[0], -1), - reduction="none", - ).sum(dim=-1) - torch.tensor( - [np.prod(self.input_dim) / 2 * np.log(np.pi * 2)] - ).to( - data.device - ) # decoding distribution is assumed unit variance N(mu, I) - log_p_x.append( log_p_x_given_z + log_p_z