Skip to content

Commit

Permalink
Simplify autoencoder by using identity layer
Browse files Browse the repository at this point in the history
  • Loading branch information
Dobiasd committed Feb 10, 2025
1 parent 2c038c6 commit f4a23f4
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 9 deletions.
1 change: 1 addition & 0 deletions include/fdeep/import_model.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1142,6 +1142,7 @@ namespace internal {
const std::string name = data["name"];

const layer_creators default_creators = {
{ "Identity", create_identity_layer },
{ "Conv1D", create_conv_2d_layer },
{ "Conv2D", create_conv_2d_layer },
{ "SeparableConv1D", create_separable_conv_2D_layer },
Expand Down
16 changes: 7 additions & 9 deletions keras_export/generate_test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
__maintainer__ = "Tobias Hermann, https://github.com/Dobiasd/frugally-deep"
__email__ = "[email protected]"

from keras.src.layers import Identity


def replace_none_with(value, shape):
"""Replace every None with a fixed value."""
Expand Down Expand Up @@ -662,19 +664,15 @@ def get_test_model_variable():

def get_test_model_autoencoder():
"""Returns a minimal autoencoder test model."""
input_img = Input(shape=(28, 28, 1), name='input_img')
x = Conv2D(4, (7, 7), activation='relu', padding='same', name='conv_encoder')(input_img)
x = MaxPooling2D((4, 4), padding='same', name='pool_encoder')(x)
encoded = Conv2D(1, (3, 3), activation='relu', padding='same', name='conv2_encoder')(x)
input_img = Input(shape=(1,), name='input_img')
encoded = Identity()(input_img) # Since it's about testing node connections, this suffices.
encoder = Model(input_img, encoded, name="encoder")

input_encoded = Input(shape=(7, 7, 1), name='input_encoded')
x = Conv2D(4, (3, 3), activation='relu', padding='same', name='conv_decoder')(input_encoded)
x = UpSampling2D((4, 4), name='upsampling_encoder')(x)
decoded = Conv2D(1, (7, 7), activation='sigmoid', padding='same', name='conv3_decoder')(x)
input_encoded = Input(shape=(1,), name='input_encoded')
decoded = Identity()(input_encoded)
decoder = Model(input_encoded, decoded, name="decoder")

autoencoder_input = Input(shape=(28, 28, 1), name='input_autoencoder')
autoencoder_input = Input(shape=(1,), name='input_autoencoder')
x = encoder(autoencoder_input)
autoencodedanddecoded = decoder(x)
autoencoder = Model(inputs=autoencoder_input, outputs=autoencodedanddecoded, name="autoencoder")
Expand Down

0 comments on commit f4a23f4

Please sign in to comment.