From 9813162f36d5d8a353788cfa3e4e54db2523194e Mon Sep 17 00:00:00 2001 From: dn070017 Date: Wed, 20 Oct 2021 10:50:20 +0200 Subject: [PATCH] fix wrong standard deviation for vade --- models/betavae.py | 5 +++-- models/vade.py | 32 +++++++++++++++----------------- 2 files changed, 18 insertions(+), 19 deletions(-) diff --git a/models/betavae.py b/models/betavae.py index 298395f..0941b50 100644 --- a/models/betavae.py +++ b/models/betavae.py @@ -52,7 +52,8 @@ def __init__(self, latent_dim, input_dims=(28, 28, 1), kernel_size=(3, 3), strid padding='same') ]) - def elbo(self, batch, beta=1.0): + def elbo(self, batch, **kwargs): + beta = kwargs['beta'] if 'beta' in kwargs else 1.0 mean_z, logvar_z, z_sample, x_pred = self.forward(batch) logpx_z = compute_log_bernouli_pdf(x_pred, batch['x']) @@ -66,7 +67,7 @@ def elbo(self, batch, beta=1.0): def train_step(self, batch, optimizers, **kwargs): with tf.GradientTape() as tape: - elbo, logpx_z, kl_divergence = self.elbo(batch, kwargs['beta']) + elbo, logpx_z, kl_divergence = self.elbo(batch, **kwargs) gradients = tape.gradient(-1 * elbo, self.trainable_variables) optimizers['primary'].apply_gradients(zip(gradients, self.trainable_variables)) diff --git a/models/vade.py b/models/vade.py index 3f19de8..e5dc515 100644 --- a/models/vade.py +++ b/models/vade.py @@ -22,13 +22,7 @@ def __init__(self, latent_dim, input_dims=(28, 28, 1), kernel_size=(3, 3), strid self.num_components = num_components self.params = Parameters(latent_dim, num_components) - def reparameterize(self, mean, logvar): - eps = tf.random.normal(shape=mean.shape) - var = tf.keras.activations.softplus(logvar) - std = tf.math.sqrt(var) - return eps * std + mean - - def elbo(self, batch, beta=1.0): + def elbo(self, batch, **kwargs): mean_z_x, logvar_z_x = self.encode(batch) z_sample = self.reparameterize(mean_z_x, logvar_z_x) x_pred = self.decode(z_sample, apply_sigmoid=False) @@ -36,7 +30,7 @@ def elbo(self, batch, beta=1.0): # transpose mean and logvar to (batch, event) and construct multivariate Gaussian dist_z_y = tfp.distributions.MultivariateNormalDiag( tf.transpose(self.params.mean_z_y, [1, 0]), - tf.sqrt(tf.exp(0.5 * tf.transpose(self.params.logvar_z_y, [1, 0]))) #sqrt + tf.exp(0.5 * tf.transpose(self.params.logvar_z_y, [1, 0])) ) dist_y = tfp.distributions.Categorical(logits=tf.squeeze(self.params.pi_y)) @@ -44,27 +38,27 @@ def elbo(self, batch, beta=1.0): logpx_z = compute_log_bernouli_pdf(x_pred, batch['x']) logpx_z = tf.reduce_sum(logpx_z, axis=[1, 2, 3]) + logpz = dist_z.log_prob(z_sample) + logqz_x = compute_log_normal_pdf(mean_z_x, logvar_z_x, z_sample) logqz_x = tf.reduce_sum(logqz_x, axis=1) elbo = logpx_z - (logqz_x - logpz) - #print(tf.reduce_mean(elbo)) return tf.reduce_mean(elbo), tf.reduce_mean(logpx_z), tf.reduce_mean(logqz_x - logpz) def qy_x(self, batch): mean_z_x, logvar_z_x = self.encode(batch) - z_sample = self.reparameterize(mean_z_x, logvar_z_x) dist_z_y = tfp.distributions.MultivariateNormalDiag( - tf.transpose(self.params.mean_z_y, [1, 0]), - tf.sqrt(tf.exp(0.5 * tf.transpose(self.params.logvar_z_y, [1, 0]))) + tf.transpose(self.params.mean_z_y, [1, 0]), + tf.exp(0.5 * tf.transpose(self.params.logvar_z_y, [1, 0])) ) # reshape to be broadcastable (batch, batch, event) - pz_y = dist_z_y.log_prob(tf.expand_dims(z_sample, -2)) - py = tf.math.log(tf.keras.activations.softmax(self.params.pi_y) + 1e-7) - qy_x = tf.keras.activations.softmax(pz_y + py) + logpz_y = dist_z_y.log_prob(tf.expand_dims(mean_z_x, -2)) + logpy = tf.math.log(tf.keras.activations.softmax(self.params.pi_y) + 1e-7) + qy_x = tf.keras.activations.softmax(logpz_y + logpy) return qy_x @@ -74,10 +68,14 @@ def generate(self, z=None, num_generated_images=15, **kwargs): if 'target' in kwargs: target = kwargs['target'] + + temperature = kwargs['temperature'] if 'temperature' in kwargs else 0.8 + z = tf.random.normal( shape=(num_generated_images, self.latent_dim), mean=self.params.mean_z_y[:, target], - stddev=tf.exp(0.5 * self.params.logvar_z_y[:, target]), - dtype=tf.float32) + stddev=temperature * tf.exp(0.5 * self.params.logvar_z_y[:, target]), + dtype=tf.float32 + ) return self.decode(z, apply_sigmoid=True) \ No newline at end of file