Skip to content

Commit

Permalink
fix wrong standard deviation for vade
Browse files Browse the repository at this point in the history
  • Loading branch information
dn070017 committed Oct 20, 2021
1 parent 555b6f9 commit 9813162
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 19 deletions.
5 changes: 3 additions & 2 deletions models/betavae.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ def __init__(self, latent_dim, input_dims=(28, 28, 1), kernel_size=(3, 3), strid
padding='same')
])

def elbo(self, batch, beta=1.0):
def elbo(self, batch, **kwargs):
beta = kwargs['beta'] if 'beta' in kwargs else 1.0
mean_z, logvar_z, z_sample, x_pred = self.forward(batch)

logpx_z = compute_log_bernouli_pdf(x_pred, batch['x'])
Expand All @@ -66,7 +67,7 @@ def elbo(self, batch, beta=1.0):

def train_step(self, batch, optimizers, **kwargs):
with tf.GradientTape() as tape:
elbo, logpx_z, kl_divergence = self.elbo(batch, kwargs['beta'])
elbo, logpx_z, kl_divergence = self.elbo(batch, **kwargs)
gradients = tape.gradient(-1 * elbo, self.trainable_variables)
optimizers['primary'].apply_gradients(zip(gradients, self.trainable_variables))

Expand Down
32 changes: 15 additions & 17 deletions models/vade.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,49 +22,43 @@ def __init__(self, latent_dim, input_dims=(28, 28, 1), kernel_size=(3, 3), strid
self.num_components = num_components
self.params = Parameters(latent_dim, num_components)

def reparameterize(self, mean, logvar):
    """Draw a reparameterized sample ``mean + std * eps`` with ``eps ~ N(0, I)``.

    The standard deviation is computed as ``sqrt(softplus(logvar))``.

    Args:
        mean: Tensor of means.
        logvar: Tensor combined element-wise with ``mean`` (presumably the
            same shape -- confirm against the encoder); it is passed through
            softplus to obtain a variance.

    Returns:
        A tensor with the shape of ``mean`` holding one sample per element.
    """
    # Use the dynamic shape so sampling also works when a dimension of
    # ``mean`` (e.g. the batch size) is statically unknown.
    eps = tf.random.normal(shape=tf.shape(mean))
    # NOTE(review): the mixture parameters in this file are converted with
    # exp(0.5 * logvar), whereas this uses softplus; the original behaviour
    # is kept here -- confirm which parameterization is intended.
    var = tf.keras.activations.softplus(logvar)
    std = tf.math.sqrt(var)
    return eps * std + mean

def elbo(self, batch, beta=1.0):
def elbo(self, batch, **kwargs):
mean_z_x, logvar_z_x = self.encode(batch)
z_sample = self.reparameterize(mean_z_x, logvar_z_x)
x_pred = self.decode(z_sample, apply_sigmoid=False)

# transpose mean and logvar to (batch, event) and construct multivariate Gaussian
dist_z_y = tfp.distributions.MultivariateNormalDiag(
tf.transpose(self.params.mean_z_y, [1, 0]),
tf.sqrt(tf.exp(0.5 * tf.transpose(self.params.logvar_z_y, [1, 0]))) #sqrt
tf.exp(0.5 * tf.transpose(self.params.logvar_z_y, [1, 0]))
)

dist_y = tfp.distributions.Categorical(logits=tf.squeeze(self.params.pi_y))
dist_z = tfp.distributions.MixtureSameFamily(dist_y, dist_z_y)

logpx_z = compute_log_bernouli_pdf(x_pred, batch['x'])
logpx_z = tf.reduce_sum(logpx_z, axis=[1, 2, 3])

logpz = dist_z.log_prob(z_sample)

logqz_x = compute_log_normal_pdf(mean_z_x, logvar_z_x, z_sample)
logqz_x = tf.reduce_sum(logqz_x, axis=1)

elbo = logpx_z - (logqz_x - logpz)
#print(tf.reduce_mean(elbo))

return tf.reduce_mean(elbo), tf.reduce_mean(logpx_z), tf.reduce_mean(logqz_x - logpz)

def qy_x(self, batch):
mean_z_x, logvar_z_x = self.encode(batch)
z_sample = self.reparameterize(mean_z_x, logvar_z_x)
dist_z_y = tfp.distributions.MultivariateNormalDiag(
tf.transpose(self.params.mean_z_y, [1, 0]),
tf.sqrt(tf.exp(0.5 * tf.transpose(self.params.logvar_z_y, [1, 0])))
tf.transpose(self.params.mean_z_y, [1, 0]),
tf.exp(0.5 * tf.transpose(self.params.logvar_z_y, [1, 0]))
)

# add a singleton axis so the input broadcasts against the per-component Gaussians: (sample batch, component batch, event)
pz_y = dist_z_y.log_prob(tf.expand_dims(z_sample, -2))
py = tf.math.log(tf.keras.activations.softmax(self.params.pi_y) + 1e-7)
qy_x = tf.keras.activations.softmax(pz_y + py)
logpz_y = dist_z_y.log_prob(tf.expand_dims(mean_z_x, -2))
logpy = tf.math.log(tf.keras.activations.softmax(self.params.pi_y) + 1e-7)
qy_x = tf.keras.activations.softmax(logpz_y + logpy)

return qy_x

Expand All @@ -74,10 +68,14 @@ def generate(self, z=None, num_generated_images=15, **kwargs):

if 'target' in kwargs:
target = kwargs['target']

temperature = kwargs['temperature'] if 'temperature' in kwargs else 0.8

z = tf.random.normal(
shape=(num_generated_images, self.latent_dim),
mean=self.params.mean_z_y[:, target],
stddev=tf.exp(0.5 * self.params.logvar_z_y[:, target]),
dtype=tf.float32)
stddev=temperature * tf.exp(0.5 * self.params.logvar_z_y[:, target]),
dtype=tf.float32
)

return self.decode(z, apply_sigmoid=True)

0 comments on commit 9813162

Please sign in to comment.