Skip to content
This repository has been archived by the owner on Nov 3, 2022. It is now read-only.

Added »concrete droupout« Wrapper #463

Open
wants to merge 9 commits into
base: master
Choose a base branch
from

Conversation

kirschte
Copy link

@kirschte kirschte commented Feb 18, 2019

- What I did
Implemented the Concrete Dropout paper w.r.t. yaringal's implementation (determines the optimal dropout rate on its own).

- How I did it
Used a keras wrapper. Modifying the keras Dense/Conv2d-Layers itself would result in more elegant code, but this version should work as well.

- How you can verify it
With the attached test file and with the following example call:

from __future__ import print_function

import keras
from keras.datasets import mnist
from keras.callbacks import Callback
from keras.models import Sequential, Model
from keras.layers import Dense, Dropout, Input
from keras.optimizers import RMSprop

import numpy as np

from cdropout import ConcreteDropout

class CD_Callback(Callback):
    def set_model(self, model):
        super(CD_Callback, self).set_model(model)
        self.probs = {}
        for layer in self.model.layers:
            if isinstance(layer, ConcreteDropout):
                self.probs[layer.name] = []
    def on_epoch_end(self, batch, logs=None):
        for layer in self.model.layers:
            if isinstance(layer, ConcreteDropout):
                self.probs[layer.name] += [np.squeeze(np.exp(layer.get_weights()[2]))]
                print('{} Propability for »Concrete Dropout«: {:.3f}' \
                    .format(layer.name, np.squeeze(np.exp(layer.get_weights()[2]))))

batch_size = 4096
num_classes = 10
epochs = 501

# the data, split between train and test sets
(x_train, y_train), (x_test, y_test) = mnist.load_data()

x_train = x_train.reshape(60000, 784)
x_test = x_test.reshape(10000, 784)
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')

# convert class vectors to binary class matrices
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)

inp = Input(shape=(784,))
model = ConcreteDropout(Dense(512, activation='relu'),
                        len(x_train),
                        seed=42,
                        name='concrete_dropout_input')(inp)
model = ConcreteDropout(Dense(512, activation='relu'),
                        len(x_train),
                        seed=42,
                        name='concrete_dropout_hidden')(model)
model = ConcreteDropout(Dense(num_classes, activation='softmax'),
                        len(x_train),
                        seed=42,
                        name='concrete_dropout_output')(model)
model = Model(inp, model)

model.summary()
model.compile(loss='categorical_crossentropy',
              optimizer=RMSprop(),
              metrics=['accuracy'])
callback = CD_Callback()
history = model.fit(x_train, y_train,
                    batch_size=batch_size,
                    epochs=epochs,
                    verbose=1,
                    validation_data=(x_test, y_test),
                    callbacks = [ callback ])
score = model.evaluate(x_test, y_test, verbose=0)
print('Test loss:', score[0])
print('Test accuracy:', score[1])

The dropout rate should converge to respecively 0.29, 0.33 and 0.82 for each layer in above example as visualized in the following plot:
dropout convergence

@kirschte kirschte force-pushed the master branch 2 times, most recently from df29773 to 0790c09 Compare February 18, 2019 20:39
@kirschte kirschte changed the title finally made cdroupout for keras-contrib available Added »concrete droupout« Feb 18, 2019
@kirschte kirschte changed the title Added »concrete droupout« Added »concrete droupout« Wrapper Feb 19, 2019
@kirschte
Copy link
Author

One test finally fails due to keras issue. See #12305

@kirschte kirschte force-pushed the master branch 7 times, most recently from c1af8c4 to ba3eb0b Compare February 21, 2019 13:37
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant