Commit

Add files via upload

Initial commit

yaysummeriscoming authored Oct 31, 2017
1 parent aed04fc commit 89db885
Showing 48 changed files with 748 additions and 0 deletions.
145 changes: 145 additions & 0 deletions BinaryNet.py
@@ -0,0 +1,145 @@
import os
import numpy as np

from CustomOps.customOps import SetSession

# Call this first, to make sure that TensorFlow registers our custom ops properly
SetSession()

from keras.models import Model, load_model
from keras.layers import Dense, Flatten, Input
from keras.optimizers import Adam
from keras.datasets import mnist, cifar100, cifar10
from keras.utils import np_utils
from keras.callbacks import ModelCheckpoint
from keras import backend as K

from CompositeLayers.ConvBNReluLayer import ConvBNReluLayer
from CompositeLayers.BinaryNetConvBNReluLayer import BinaryNetConvBNReluLayer, BinaryNetActivation
from CustomLayers.CustomLayersDictionary import customLayersDictionary
from CompositeLayers.XNORConvLayer import XNORConvBNReluLayer, BNXNORConvReluLayer
from NetworkParameters import NetworkParameters
from CustomLayers.CustomLayersDictionary import customLayerCallbacks

np.random.seed(1337) # for reproducibility


def CreateModel(input_shape, nb_classes, parameters):
model_input = Input(shape=input_shape)

output = model_input

if parameters.binarisation_type == 'BinaryNet':
print('Using BinaryNet binary convolution layers')
layerType = BinaryNetConvBNReluLayer
elif parameters.binarisation_type == 'XNORNet':
print('Using XNORNet binary convolution layers')
layerType = BNXNORConvReluLayer
else:
        raise ValueError('Unsupported binarisation type: %s' % parameters.binarisation_type)

# As per the paper, the first layer can't be binary
output = ConvBNReluLayer(input=output, nb_filters=16, border='valid', kernel_size=(3, 3), stride=(1, 1))

    # Add an extra binarisation layer here, as the Theano backend needs explicit input binarisation
if K.backend() == 'theano':
output = BinaryNetActivation()(output)

output = layerType(input=output, nb_filters=32, border='valid', kernel_size=(3, 3), stride=(1, 1))
output = layerType(input=output, nb_filters=32, border='valid', kernel_size=(3, 3), stride=(1, 1))
output = layerType(input=output, nb_filters=32, border='valid', kernel_size=(3, 3), stride=(1, 1))
output = layerType(input=output, nb_filters=32, border='valid', kernel_size=(3, 3), stride=(1, 1))
output = layerType(input=output, nb_filters=32, border='valid', kernel_size=(3, 3), stride=(1, 1))

output = Flatten()(output)
output = Dense(nb_classes, use_bias=True, activation='softmax')(output)

model = Model(inputs=model_input, outputs=output)

model.summary()

return model



############################
# Parameters


if os.name == 'posix':
modelDirectory = '/Volumes/MacStorage/DeepLearningData/current_SENet/'
else:
modelDirectory = 'F:/DeepLearning/models/current_SENet/'


parameters = NetworkParameters(modelDirectory)
parameters.nb_epochs = 1
parameters.batch_size = 32
parameters.lr = 0.0005
parameters.batch_scale_factor = 8
parameters.decay = 0.001

parameters.binarisation_type = 'BinaryNet' # Either 'BinaryNet' or 'XNORNet'

parameters.lr *= parameters.batch_scale_factor
parameters.batch_size *= parameters.batch_scale_factor

print('Learning rate is: %f' % parameters.lr)
print('Batch size is: %d' % parameters.batch_size)

optimiser = Adam(lr=parameters.lr, decay=parameters.decay)

############################
# Data

(X_train, y_train), (X_test, y_test) = mnist.load_data()

y_train = np.squeeze(y_train)
y_test = np.squeeze(y_test)

if len(X_train.shape) < 4:
X_train = np.expand_dims(X_train, -1)
X_test = np.expand_dims(X_test, -1)

input_shape = X_train.shape[1:]

X_train = X_train.astype('float32')
X_test = X_test.astype('float32')

X_train = X_train / 256.0
X_test = X_test / 256.0

nb_classes = y_train.max() + 1

y_test_cat = np_utils.to_categorical(y_test, nb_classes + 1)
y_train_cat = np_utils.to_categorical(y_train, nb_classes + 1)


############################
# Training

model = CreateModel(input_shape=input_shape, nb_classes=nb_classes+1, parameters=parameters)

model.compile(loss='categorical_crossentropy',
optimizer=optimiser,
metrics=['accuracy'])

checkpointCallback = ModelCheckpoint(filepath=parameters.modelSaveName, verbose=1)
bestCheckpointCallback = ModelCheckpoint(filepath=parameters.bestModelSaveName, verbose=1, save_best_only=True)

model.fit(x=X_train,
y=y_train_cat,
batch_size=parameters.batch_size,
epochs=parameters.nb_epochs,
callbacks=[checkpointCallback, bestCheckpointCallback] + customLayerCallbacks,
validation_data=(X_test, y_test_cat),
shuffle=True,
verbose=1
)


print('Testing')
modelTest = load_model(filepath=parameters.bestModelSaveName, custom_objects=customLayersDictionary)

# Evaluate the reloaded best model, not the last-epoch weights still held in 'model'
validationAccuracy = modelTest.evaluate(X_test, y_test_cat, verbose=0)
print('\nBest Keras validation accuracy: %.2f%%\n' % (100.0 * validationAccuracy[1]))
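
The CustomOps module imported at the top of BinaryNet.py is not included in this commit excerpt. As a hedged sketch only, a straight-through sign op of the kind SetSession presumably registers might look as follows under TensorFlow 1.x; the op name 'PassthroughSign' and the gradient-clipping choice are assumptions based on the BinaryNet paper, not code from this repository.

import tensorflow as tf

@tf.RegisterGradient('PassthroughSign')
def _passthrough_sign_grad(op, grad):
    # Assumed straight-through estimator: let gradients pass where |x| <= 1
    # (the hard-tanh clip from the BinaryNet paper) and zero them elsewhere.
    x = op.inputs[0]
    return grad * tf.cast(tf.abs(x) <= 1.0, grad.dtype)

def passthroughSign(x):
    # Forward pass is tf.sign(x); the backward pass uses the override above.
    g = tf.get_default_graph()
    with g.gradient_override_map({'Sign': 'PassthroughSign'}):
        return tf.sign(x)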
31 changes: 31 additions & 0 deletions CompositeLayers/BinaryNetConvBNReluLayer.py
@@ -0,0 +1,31 @@
from keras.layers import BatchNormalization, Activation
from CustomLayers.BinaryNetLayer import BinaryNetConv2D, BinaryNetActivation
from keras import backend as K

def BinaryNetConvBNReluLayer(input, nb_filters, border, kernel_size, stride, use_bias=True, data_format='channels_last', use_activation=False):
output = input

    # BinaryNet uses binarisation as the activation
    # To get the graphs to compile properly, with Theano the input binarisation is added as a separate layer
    # The TensorFlow implementation performs the input binarisation inside the layer definition
if K.backend() == 'theano':
output = BinaryNetActivation()(output)

output = BinaryNetConv2D(nb_filters,
kernel_size,
use_bias=use_bias,
padding=border,
strides=stride,
data_format=data_format,
)(output)

    # Add output binarisation as a separate layer for Theano
if K.backend() == 'theano':
output = BinaryNetActivation()(output)

output = BatchNormalization()(output)

if use_activation:
output = Activation('relu')(output)

return output
Binary file added CompositeLayers/BinaryNetConvBNReluLayer.pyc
17 changes: 17 additions & 0 deletions CompositeLayers/ConvBNReluLayer.py
@@ -0,0 +1,17 @@
from keras.layers import Activation, BatchNormalization
from keras.layers import Convolution2D

def ConvBNReluLayer(input, nb_filters, border, kernel_size, stride, use_bias=True, data_format='channels_last'):

output = Convolution2D(filters=nb_filters,
kernel_size=kernel_size,
strides=stride,
padding=border,
data_format=data_format,
use_bias=use_bias
)(input)

output = BatchNormalization()(output)
output = Activation('relu')(output)

return output
Binary file added CompositeLayers/ConvBNReluLayer.pyc
63 changes: 63 additions & 0 deletions CompositeLayers/XNORConvLayer.py
@@ -0,0 +1,63 @@
from keras.layers import Activation, BatchNormalization
from CustomLayers.XNORNetLayer import XNORNetConv2D

def BNXNORConvReluLayer(input,
nb_filters,
border,
kernel_size,
stride,
use_BN=True,
use_bias=False,
use_activation=True,
binarise_input=True,
data_format='channels_last'):

output = input

if use_BN:
output = BatchNormalization()(output)

output = XNORNetConv2D(filters=nb_filters,
kernel_size=kernel_size,
use_bias=use_bias,
padding=border,
strides=stride,
data_format=data_format,
binarise_input=binarise_input
)(output)

if use_activation:
output = Activation('relu')(output)

return output


def XNORConvBNReluLayer(input,
nb_filters,
border,
kernel_size,
stride,
use_BN=True,
use_bias=False,
use_activation=True,
binarise_input=True,
data_format='channels_last'):

output = input

    output = XNORNetConv2D(filters=nb_filters,
kernel_size=kernel_size,
use_bias=use_bias,
padding=border,
strides=stride,
data_format=data_format,
binarise_input=binarise_input
)(output)

if use_BN:
output = BatchNormalization()(output)

if use_activation:
output = Activation('relu')(output)

return output
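
CustomLayers/XNORNetLayer.py is not shown in this excerpt. For context, the XNOR-Net paper approximates each weight tensor as W ≈ α·sign(W), with one scaling factor α per output filter (the mean absolute weight). A minimal NumPy sketch of that approximation follows; the function name is illustrative, not taken from this repository.

import numpy as np

def xnor_weight_approximation(weights):
    # Kernel layout is (rows, cols, input_channels, output_channels),
    # the Keras 2 arrangement noted in BinaryNetLayer.py.
    # One alpha per output filter: the mean absolute value of its weights.
    alpha = np.mean(np.abs(weights), axis=(0, 1, 2), keepdims=True)
    return alpha * np.sign(weights)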
Binary file added CompositeLayers/__init__.pyc
Binary file added CompositeLayers/__pycache__/__init__.cpython-35.pyc
105 changes: 105 additions & 0 deletions CustomLayers/BinaryNetLayer.py
@@ -0,0 +1,105 @@
import keras.backend as K
import numpy as np
from keras.engine import InputSpec
from keras.engine import Layer
from keras.layers import Convolution2D

from CustomOps.customOps import passthroughSign


class BinaryNetActivation(Layer):

def __init__(self, **kwargs):
super(BinaryNetActivation, self).__init__(**kwargs)
# self.supports_masking = True

def build(self, input_shape):
super(BinaryNetActivation, self).build(input_shape) # Be sure to call this somewhere!

def call(self, inputs):
# In BinaryNet, the output activation is binarised (normally done at the input to each layer in our implementation)
return passthroughSign(inputs)

def get_config(self):
base_config = super(BinaryNetActivation, self).get_config()
return base_config

def compute_output_shape(self, input_shape):
return input_shape

class BinaryNetConv2D(Convolution2D):
"""2D binary convolution layer (e.g. spatial convolution over images).
This is an implementation of the BinaryNet layer described in:
Binarized Neural Networks: Training Deep Neural Networks with Weights and Activations Constrained to +1 or -1
    It's based on the Convolution2D class and takes an identical argument list.
    NOTE: The weight binarisation functionality is implemented using an 'on batch end' function,
which must be called at the end of every batch (ideally using a callback). Currently this functionality
is implemented using Numpy. In practice this incurs a negligible performance penalty,
as this function uses far fewer operations than the base convolution operation.
# Arguments
Same as base Convolution2D layer
# Input shape
4D tensor with shape:
`(samples, channels, rows, cols)` if data_format='channels_first'
or 4D tensor with shape:
`(samples, rows, cols, channels)` if data_format='channels_last'.
# Output shape
4D tensor with shape:
`(samples, filters, new_rows, new_cols)` if data_format='channels_first'
or 4D tensor with shape:
`(samples, new_rows, new_cols, filters)` if data_format='channels_last'.
`rows` and `cols` values might have changed due to padding.
"""

def build(self, input_shape):
        # Call the build function of the base class (in this case, convolution)
        super().build(input_shape)

print('Use bias is: %d' % self.use_bias)

        # Get the initialised weights and save them as the 'full precision' weights
weights = K.get_value(self.weights[0])
self.fullPrecisionWeights = weights.copy()

# Compute the binary approximated weights & save ready for the first batch
B = np.sign(self.fullPrecisionWeights)
self.lastIterationWeights = B.copy()
K.set_value(self.weights[0], B)


def call(self, inputs):

        # For Theano, binarisation is done as a separate layer
if K.backend() == 'tensorflow':
binarisedInput = passthroughSign(inputs)
else:
binarisedInput = inputs

return super().call(binarisedInput)


def on_batch_end(self):
        # Weight arrangement is: (kernel_size, kernel_size, num_input_channels, num_output_channels)
        # for both data formats in Keras 2 notation

# Work out the weights update from the last batch and then apply this to the full precision weights
# The current weights correspond to the binarised weights + last batch update
newWeights = K.get_value(self.weights[0])
weightsUpdate = newWeights - self.lastIterationWeights
self.fullPrecisionWeights += weightsUpdate
self.fullPrecisionWeights = np.clip(self.fullPrecisionWeights, -1., 1.)

# Work out new approximated weights based off the full precision values
B = np.sign(self.fullPrecisionWeights)

# Save the weights, both in the keras kernel and a reference variable
# so that we can compute the weights update that keras makes
self.lastIterationWeights = B.copy()
K.set_value(self.weights[0], B)
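
The NOTE in the BinaryNetConv2D docstring says on_batch_end() must be called after every batch, ideally from a callback; customLayerCallbacks, imported in BinaryNet.py but not shown in this commit excerpt, presumably supplies one. A minimal sketch of such a callback, under that assumption:

from keras.callbacks import Callback

class WeightBinarisationCallback(Callback):
    # Re-binarises the kernel of any layer exposing on_batch_end(),
    # e.g. BinaryNetConv2D above, once per training batch.
    def on_batch_end(self, batch, logs=None):
        for layer in self.model.layers:
            if hasattr(layer, 'on_batch_end'):
                layer.on_batch_end()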
Binary file added CustomLayers/BinaryNetLayer.pyc