Skip to content
This repository has been archived by the owner on Nov 3, 2022. It is now read-only.

Commit

Permalink
refactored tests
Browse files Browse the repository at this point in the history
  • Loading branch information
kirschte committed Feb 19, 2019
1 parent 9007618 commit 95214b6
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 19 deletions.
15 changes: 9 additions & 6 deletions keras_contrib/wrappers/cdropout.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from keras.initializers import RandomUniform
from keras.layers import InputSpec
from keras.layers.wrappers import Wrapper
from keras_contrib.utils.test_utils import to_tuple


class ConcreteDropout(Wrapper):
Expand Down Expand Up @@ -34,8 +35,9 @@ class ConcreteDropout(Wrapper):
Also known as inverse observation noise.
prob_init: Tuple[float, float].
Probability lower / upper bounds of dropout rate initialization.
temp: float. Temperature. Not used to be optimized.
seed: Seed for random probability sampling.
temp: float. Temperature.
Determines the speed of probability adjustments.
seed: Seed for random probability sampling.
# References
- [Concrete Dropout](https://arxiv.org/pdf/1705.07832.pdf)
Expand All @@ -44,10 +46,10 @@ class ConcreteDropout(Wrapper):
def __init__(self,
layer,
n_data,
length_scale=2e-2,
length_scale=5e-2,
model_precision=1,
prob_init=(0.1, 0.5),
temp=0.1,
temp=0.4,
seed=None,
**kwargs):
assert 'kernel_regularizer' not in kwargs
Expand All @@ -64,7 +66,7 @@ def __init__(self,

def _concrete_dropout(self, inputs, layer_type):
"""Applies concrete dropout.
Used at training time (gradients can be propagated)
Used at training time (gradients can be propagated).
# Arguments
inputs: Input.
Expand Down Expand Up @@ -99,6 +101,7 @@ def _concrete_dropout(self, inputs, layer_type):
return inputs

def build(self, input_shape=None):
input_shape = to_tuple(input_shape)
if len(input_shape) == 2: # Dense_layer
input_dim = np.prod(input_shape[-1]) # we drop only last dim
elif len(input_shape) == 4: # Conv_layer
Expand Down Expand Up @@ -126,7 +129,7 @@ def build(self, input_shape=None):

super(ConcreteDropout, self).build(input_shape)

# initialise regularizer / prior KL term
# initialize regularizer / prior KL term
weight = self.layer.kernel
kernel_regularizer = (
self.weight_regularizer
Expand Down
101 changes: 88 additions & 13 deletions tests/keras_contrib/wrappers/test_cdropout.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@
from keras_contrib.wrappers import ConcreteDropout


def test_cdropout():
@pytest.fixture(scope='module')
def dense_model():
"""Initialize to be tested dense model. Executed once.
"""
# DATA
in_dim = 20
init_prop = .1
Expand All @@ -28,34 +31,72 @@ def test_cdropout():
# Model, reference w/o Dropout
x_ref = dense(inputs)
model_ref = Model(inputs, x_ref)
model_ref.compile(loss='mse', optimizer='rmsprop')
model_ref.compile(loss=None, optimizer='rmsprop')

return {'model': model,
'model_ref': model_ref,
'concrete_dropout': cd,
'init_prop': init_prop,
'in_dim': in_dim,
'X': X}


def test_cdropout_dense_3rdweight(dense_model):
"""Check about correct 3rd weight (equal to initial value)
"""
model = dense_model['model']
init_prop = dense_model['init_prop']

# CHECKS
# Check about correct 3rd weight (equal to initial value)
W = model.get_weights()
assert_array_almost_equal(W[2], [np.log(init_prop)])

# Check if ConcreteDropout in prediction phase is the same as no dropout

def test_cdropout_dense_identity(dense_model):
"""Check if ConcreteDropout in prediction phase is the same as no dropout
"""
model = dense_model['model']
model_ref = dense_model['model_ref']
X = dense_model['X']

out = model.predict(X)
out_ref = model_ref.predict(X)
assert_allclose(out, out_ref, atol=1e-5)

# Check if ConcreteDropout has the right amount of losses deposited

def test_cdropout_dense_loss(dense_model):
"""Check if ConcreteDropout has the right amount of losses deposited
"""
model = dense_model['model']

assert_equal(len(model.losses), 1)

# Check if the loss correspons the the desired value

def test_cdropout_dense_loss_value(dense_model):
"""Check if the loss corresponds the the desired value
"""
model = dense_model['model']
X = dense_model['X']
cd = dense_model['concrete_dropout']
in_dim = dense_model['in_dim']

def sigmoid(x):
return 1. / (1. + np.exp(-x))

W = model.get_weights()
p = np.squeeze(sigmoid(W[2]))
kernel_regularizer = cd.weight_regularizer * np.sum(np.square(W[0])) / (1. - p)
dropout_regularizer = (p * np.log(p) + (1. - p) * np.log(1. - p))
dropout_regularizer *= cd.dropout_regularizer * in_dim
loss = np.sum(kernel_regularizer + dropout_regularizer)

eval_loss = model.evaluate(X)
assert_approx_equal(eval_loss, loss)


def test_cdropout_conv():
@pytest.fixture(scope='module')
def conv2d_model():
"""Initialize to be tested conv model. Executed once.
"""
# DATA
in_dim = 20
init_prop = .1
Expand All @@ -75,27 +116,61 @@ def test_cdropout_conv():
model_ref = Model(inputs, x_ref)
model_ref.compile(loss=None, optimizer='rmsprop')

# CHECKS
# Check about correct 3rd weight (equal to initial value)
return {'model': model,
'model_ref': model_ref,
'concrete_dropout': cd,
'init_prop': init_prop,
'in_dim': in_dim,
'X': X}


def test_cdropout_conv2d_3rdweight(conv2d_model):
"""Check about correct 3rd weight (equal to initial value)
"""
model = conv2d_model['model']
init_prop = conv2d_model['init_prop']

W = model.get_weights()
assert_array_almost_equal(W[2], [np.log(init_prop)])

# Check if ConcreteDropout in prediction phase is the same as no dropout

def test_cdropout_conv2d_identity(conv2d_model):
"""Check if ConcreteDropout in prediction phase is the same as no dropout
"""
model = conv2d_model['model']
model_ref = conv2d_model['model_ref']
X = conv2d_model['X']

out = model.predict(X)
out_ref = model_ref.predict(X)
assert_allclose(out, out_ref, atol=1e-5)

# Check if ConcreteDropout has the right amount of losses deposited

def test_cdropout_conv2d_loss(conv2d_model):
"""Check if ConcreteDropout has the right amount of losses deposited
"""
model = conv2d_model['model']

assert_equal(len(model.losses), 1)

# Check if the loss correspons the the desired value

def test_cdropout_conv2d_loss_value(conv2d_model):
"""Check if the loss corresponds the the desired value
"""
model = conv2d_model['model']
X = conv2d_model['X']
cd = conv2d_model['concrete_dropout']

def sigmoid(x):
return 1. / (1. + np.exp(-x))

W = model.get_weights()
p = np.squeeze(sigmoid(W[2]))
kernel_regularizer = cd.weight_regularizer * np.sum(np.square(W[0])) / (1. - p)
dropout_regularizer = (p * np.log(p) + (1. - p) * np.log(1. - p))
dropout_regularizer *= cd.dropout_regularizer * 1 # only channels are dropped
loss = np.sum(kernel_regularizer + dropout_regularizer)

eval_loss = model.evaluate(X)
assert_approx_equal(eval_loss, loss)

Expand Down

0 comments on commit 95214b6

Please sign in to comment.