Skip to content

Commit

Permalink
More work on autoencoder.
Browse files Browse the repository at this point in the history
Also, use groupyr 0.3.3
  • Loading branch information
arokem committed Feb 24, 2024
1 parent e31f680 commit 68f03d5
Showing 1 changed file with 20 additions and 10 deletions.
30 changes: 20 additions & 10 deletions afqinsight/nn/tf_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,21 +336,25 @@ def fc_autoencoder(input_shape, encoding_dim=None, verbose=False):
return model


def cnn_autoencoder(input_shape, verbose=False):
def cnn_autoencoder(input_shape, encoding_dim=8, verbose=False):
"""
Convolutional autoencoder
"""
ip = Input(shape=input_shape)
# Encoder
x = Conv1D(32, (3), activation="relu", padding="same")(ip)
x = MaxPooling1D((2), padding="same")(x)
x = Conv1D(32, (3), activation="relu", padding="same")(x)
x = MaxPooling1D((2), padding="same")(x)

x = Conv1D(32, 3, activation="relu", padding="same")(ip)
x = MaxPooling1D(2, padding="same")(x)
x = Conv1D(16, 3, activation="relu", padding="same")(x)
x = MaxPooling1D(2, padding="same")(x)
shape = x.shape
# Latent
x = Flatten()(x)
x = Dense(encoding_dim, activation="relu")(x)
# Decoder
x = Conv1DTranspose(32, (3), strides=2, activation="relu", padding="same")(x)
x = Conv1DTranspose(32, (3), strides=2, activation="relu", padding="same")(x)
x = Conv1D(1, (3), activation="sigmoid", padding="same")(x)
x = Reshape(shape)(x)
x = Conv1DTranspose(32, 3, strides=2, activation="relu", padding="same")(x)
x = Conv1DTranspose(16, 3, strides=2, activation="relu", padding="same")(x)
x = Conv1DTranspose(1, 3, activation="sigmoid", padding="same")(x)

model = Model([ip], [x])
if verbose:
Expand Down Expand Up @@ -401,7 +405,8 @@ def _fc_vae_decoder(input_shape, encoding_dim=None, verbose=False):
fc = Dense((input_shape[0] * input_shape[1]) // 4, activation="relu")(fc)
fc = Dense((input_shape[0] * input_shape[1]) // 2, activation="relu")(fc)
pre_out = Dense((input_shape[0] * input_shape[1]))(fc)
return Reshape(input_shape)(pre_out)
out = Reshape(input_shape)(pre_out)
return Model([ip], [out], name="decoder")


class _VAE(Model):
Expand All @@ -427,6 +432,11 @@ def metrics(self):
self.kl_loss_tracker,
]

def call(self, inputs):
z_mean, z_log_var, z = self.encoder(inputs)
reconstructed = self.decoder(z)
return reconstructed

def train_step(self, data):
with tf.GradientTape() as tape:
z_mean, z_log_var, z = self.encoder(data)
Expand Down

0 comments on commit 68f03d5

Please sign in to comment.