diff --git a/README.md b/README.md index def0674..f64e8c0 100644 --- a/README.md +++ b/README.md @@ -53,6 +53,8 @@ and is distributed under the MIT license. | [MobileNetV2(alpha=1.0)](keras_applications/mobilenet_v2.py) | 224 | 71.336 | 90.142 | 3.5M | 2.3M | [[paper]](https://arxiv.org/abs/1801.04381) [[tf-models]](https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet_v2.py) | | [MobileNetV2(alpha=1.3)](keras_applications/mobilenet_v2.py) | 224 | 74.680 | 92.122 | 5.4M | 3.8M | [[paper]](https://arxiv.org/abs/1801.04381) [[tf-models]](https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet_v2.py) | | [MobileNetV2(alpha=1.4)](keras_applications/mobilenet_v2.py) | 224 | 75.230 | 92.422 | 6.2M | 4.4M | [[paper]](https://arxiv.org/abs/1801.04381) [[tf-models]](https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet_v2.py) | +| [MobileNetV3(small)](keras_applications/mobilenet_v3.py) | 224 | 68.076 | 87.800 | 2.6M | 0.9M | [[paper]](https://arxiv.org/abs/1905.02244) [[tf-models]](https://github.com/tensorflow/models/tree/master/research/slim/nets/mobilenet/mobilenet_v3.py) | +| [MobileNetV3(large)](keras_applications/mobilenet_v3.py) | 224 | 75.556 | 92.708 | 5.5M | 3.0M | [[paper]](https://arxiv.org/abs/1905.02244) [[tf-models]](https://github.com/tensorflow/models/tree/master/research/slim/nets/mobilenet/mobilenet_v3.py) | | [DenseNet121](keras_applications/densenet.py) | 224 | 74.972 | 92.258 | 8.1M | 7.0M | [[paper]](https://arxiv.org/abs/1608.06993) [[torch]](https://github.com/liuzhuang13/DenseNet/blob/master/models/densenet.lua) | | [DenseNet169](keras_applications/densenet.py) | 224 | 76.176 | 93.176 | 14.3M | 12.6M | [[paper]](https://arxiv.org/abs/1608.06993) [[torch]](https://github.com/liuzhuang13/DenseNet/blob/master/models/densenet.lua) | | [DenseNet201](keras_applications/densenet.py) | 224 | 77.320 | 93.620 | 20.2M | 18.3M | [[paper]](https://arxiv.org/abs/1608.06993) [[torch]](https://github.com/liuzhuang13/DenseNet/blob/master/models/densenet.lua) | diff --git a/keras_applications/__init__.py b/keras_applications/__init__.py index 1193048..dd6f6b6 100644 --- a/keras_applications/__init__.py +++ b/keras_applications/__init__.py @@ -57,6 +57,7 @@ def correct_pad(backend, inputs, kernel_size): from . import xception from . import mobilenet from . import mobilenet_v2 +from . import mobilenet_v3 from . import densenet from . import nasnet from . import resnet diff --git a/keras_applications/mobilenet_v3.py b/keras_applications/mobilenet_v3.py new file mode 100644 index 0000000..c404c52 --- /dev/null +++ b/keras_applications/mobilenet_v3.py @@ -0,0 +1,549 @@ +"""MobileNet v3 models for Keras. 
+
+The following table describes the performance of MobileNet v3 models:
+------------------------------------------------------------------------
+MACs stands for Multiply-Adds
+
+| Classification Checkpoint | MACs(M) | Parameters(M) | Top1 Accuracy | Pixel1 CPU(ms) |
+|---|---|---|---|---|
+| [mobilenet_v3_large_1.0_224] | 217 | 5.4 | 75.6 | 51.2 |
+| [mobilenet_v3_large_0.75_224] | 155 | 4.0 | 73.3 | 39.8 |
+| [mobilenet_v3_large_minimalistic_1.0_224] | 209 | 3.9 | 72.3 | 44.1 |
+| [mobilenet_v3_small_1.0_224] | 66 | 2.9 | 68.1 | 15.8 |
+| [mobilenet_v3_small_0.75_224] | 44 | 2.4 | 65.4 | 12.8 |
+| [mobilenet_v3_small_minimalistic_1.0_224] | 65 | 2.0 | 61.9 | 12.2 |
+
+The weights for all 6 models were obtained and translated from
+TensorFlow checkpoints found [here]
+(https://github.com/tensorflow/models/tree/master/research/
+slim/nets/mobilenet/README.md).
+
+# Reference
+
+This file contains building code for MobileNetV3, based on
+[Searching for MobileNetV3]
+(https://arxiv.org/pdf/1905.02244.pdf) (ICCV 2019)
+
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import warnings
+
+from . import correct_pad
+from . import get_submodules_from_kwargs
+from . import imagenet_utils
+from .imagenet_utils import _obtain_input_shape
+from .imagenet_utils import decode_predictions
+
+
+backend = None
+layers = None
+models = None
+keras_utils = None
+
+BASE_WEIGHT_PATH = ('https://github.com/DrSlink/mobilenet_v3_keras/'
+                    'releases/download/v1.0/')
+WEIGHTS_HASHES = {
+    'large_224_0.75_float': (
+        '765b44a33ad4005b3ac83185abf1d0eb',
+        'c256439950195a46c97ede7c294261c6'),
+    'large_224_1.0_float': (
+        '59e551e166be033d707958cf9e29a6a7',
+        '12c0a8442d84beebe8552addf0dcb950'),
+    'large_minimalistic_224_1.0_float': (
+        '675e7b876c45c57e9e63e6d90a36599c',
+        'c1cddbcde6e26b60bdce8e6e2c7cae54'),
+    'small_224_0.75_float': (
+        'cb65d4e5be93758266aa0a7f2c6708b7',
+        'c944bb457ad52d1594392200b48b4ddb'),
+    'small_224_1.0_float': (
+        '8768d4c2e7dee89b9d02b2d03d65d862',
+        '5bec671f47565ab30e540c257bba8591'),
+    'small_minimalistic_224_1.0_float': (
+        '99cd97fb2fcdad2bf028eb838de69e37',
+        '1efbf7e822e03f250f45faa3c6bbe156'),
+}
+
+
+def preprocess_input(x, **kwargs):
+    """Preprocesses a numpy array encoding a batch of images.
+
+    # Arguments
+        x: a 4D numpy array consisting of RGB values within [0, 255].
+
+    # Returns
+        Preprocessed array.
+    """
+    return imagenet_utils.preprocess_input(x, mode='tf', **kwargs)
+
+
+def relu(x):
+    return layers.ReLU()(x)
+
+
+def hard_sigmoid(x):
+    # ReLU6(x + 3) / 6: a piecewise-linear approximation of the sigmoid.
+    return layers.ReLU(6.)(x + 3.) * (1. / 6.)
+
+
+def hard_swish(x):
+    # hard_swish(x) = x * hard_sigmoid(x), a cheap approximation of swish.
+    return layers.Multiply()([layers.Activation(hard_sigmoid)(x), x])
+
+
+# This function is taken from the original tf repo.
+# It ensures that all layers have a channel number that is divisible by 8.
+# It can be seen here:
+# https://github.com/tensorflow/models/blob/master/research/
+# slim/nets/mobilenet/mobilenet.py
+
+
+def _depth(v, divisor=8, min_value=None):
+    if min_value is None:
+        min_value = divisor
+    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
+    # Make sure that round down does not go down by more than 10%.
+    if new_v < 0.9 * v:
+        new_v += divisor
+    return new_v
+
+
+def _se_block(inputs, filters, se_ratio, prefix):
+    x = layers.GlobalAveragePooling2D(name=prefix + 'squeeze_excite/AvgPool')(inputs)
+    if backend.image_data_format() == 'channels_first':
+        x = layers.Reshape((filters, 1, 1))(x)
+    else:
+        x = layers.Reshape((1, 1, filters))(x)
+    x = layers.Conv2D(_depth(filters * se_ratio),
+                      kernel_size=1,
+                      padding='same',
+                      name=prefix + 'squeeze_excite/Conv')(x)
+    x = layers.ReLU(name=prefix + 'squeeze_excite/Relu')(x)
+    x = layers.Conv2D(filters,
+                      kernel_size=1,
+                      padding='same',
+                      name=prefix + 'squeeze_excite/Conv_1')(x)
+    x = layers.Activation(hard_sigmoid)(x)
+    if backend.backend() == 'theano':
+        # For the Theano backend, we have to explicitly make
+        # the excitation weights broadcastable.
+        x = layers.Lambda(
+            lambda br: backend.pattern_broadcast(br, [True, True, True, False]),
+            output_shape=lambda input_shape: input_shape,
+            name=prefix + 'squeeze_excite/broadcast')(x)
+    x = layers.Multiply(name=prefix + 'squeeze_excite/Mul')([inputs, x])
+    return x
+
+
+def _inverted_res_block(x, expansion, filters, kernel_size, stride,
+                        se_ratio, activation, block_id):
+    channel_axis = 1 if backend.image_data_format() == 'channels_first' else -1
+    shortcut = x
+    prefix = 'expanded_conv/'
+    infilters = backend.int_shape(x)[channel_axis]
+    if block_id:
+        # Expand
+        prefix = 'expanded_conv_{}/'.format(block_id)
+        x = layers.Conv2D(_depth(infilters * expansion),
+                          kernel_size=1,
+                          padding='same',
+                          use_bias=False,
+                          name=prefix + 'expand')(x)
+        x = layers.BatchNormalization(axis=channel_axis,
+                                      epsilon=1e-3,
+                                      momentum=0.999,
+                                      name=prefix + 'expand/BatchNorm')(x)
+        x = layers.Activation(activation)(x)
+
+    if stride == 2:
+        x = layers.ZeroPadding2D(padding=correct_pad(backend, x, kernel_size),
+                                 name=prefix + 'depthwise/pad')(x)
+    x = layers.DepthwiseConv2D(kernel_size,
+                               strides=stride,
+                               padding='same' if stride == 1 else 'valid',
+                               use_bias=False,
+                               name=prefix + 'depthwise')(x)
+    x = layers.BatchNormalization(axis=channel_axis,
+                                  epsilon=1e-3,
+                                  momentum=0.999,
+                                  name=prefix + 'depthwise/BatchNorm')(x)
+    x = layers.Activation(activation)(x)
+
+    if se_ratio:
+        x = _se_block(x, _depth(infilters * expansion), se_ratio, prefix)
+
+    x = layers.Conv2D(filters,
+                      kernel_size=1,
+                      padding='same',
+                      use_bias=False,
+                      name=prefix + 'project')(x)
+    x = layers.BatchNormalization(axis=channel_axis,
+                                  epsilon=1e-3,
+                                  momentum=0.999,
+                                  name=prefix + 'project/BatchNorm')(x)
+
+    if stride == 1 and infilters == filters:
+        x = layers.Add(name=prefix + 'Add')([shortcut, x])
+    return x
+
+
+def MobileNetV3(stack_fn,
+                last_point_ch,
+                input_shape=None,
+                alpha=1.0,
+                model_type='large',
+                minimalistic=False,
+                include_top=True,
+                weights='imagenet',
+                input_tensor=None,
+                classes=1000,
+                pooling=None,
+                dropout_rate=0.2,
+                **kwargs):
+    """Instantiates the MobileNetV3 architecture.
+
+    # Arguments
+        stack_fn: a function that returns the output tensor for the
+            stacked residual blocks.
+        last_point_ch: number of channels at the last layer (before the top).
+        input_shape: optional shape tuple, to be specified if you would
+            like to use a model with an input image resolution that is not
+            (224, 224, 3).
+            It should have exactly 3 input channels.
+            You can also omit this option if you would like
+            to infer input_shape from an input_tensor.
+            If you choose to include both input_tensor and input_shape, then
+            input_shape will be used if they match; if the shapes
+            do not match, an error will be raised.
+            E.g. `(160, 160, 3)` would be one valid value.
+        alpha: controls the width of the network. This is known as the
+            depth multiplier in the MobileNetV3 paper, but the name is kept
+            for consistency with MobileNetV1 in Keras.
+            - If `alpha` < 1.0, proportionally decreases the number
+                of filters in each layer.
+            - If `alpha` > 1.0, proportionally increases the number
+                of filters in each layer.
+            - If `alpha` = 1, the default number of filters from the paper
+                is used at each layer.
+        model_type: MobileNetV3 is defined as two models: large and small.
+            These models are targeted at high and low resource use cases
+            respectively.
+        minimalistic: in addition to large and small models, this module
+            also contains so-called minimalistic models; these models have
+            the same per-layer dimensions as MobileNetV3, but they don't
+            utilize any of the advanced blocks (squeeze-and-excite units,
+            hard-swish, and 5x5 convolutions). While these models are less
+            efficient on CPU, they are much more performant on GPU/DSP.
+        include_top: whether to include the fully-connected
+            layer at the top of the network.
+        weights: one of `None` (random initialization),
+            'imagenet' (pre-training on ImageNet),
+            or the path to the weights file to be loaded.
+        input_tensor: optional Keras tensor (i.e. output of
+            `layers.Input()`)
+            to use as image input for the model.
+        classes: optional number of classes to classify images
+            into, only to be specified if `include_top` is True, and
+            if no `weights` argument is specified.
+        pooling: optional pooling mode for feature extraction
+            when `include_top` is `False`.
+            - `None` means that the output of the model will be
+                the 4D tensor output of the
+                last convolutional layer.
+            - `avg` means that global average pooling
+                will be applied to the output of the
+                last convolutional layer, and thus
+                the output of the model will be a 2D tensor.
+            - `max` means that global max pooling will
+                be applied.
+        dropout_rate: fraction of the input units to drop in the last layer.
+
+    # Returns
+        A Keras model instance.
+
+    # Raises
+        ValueError: in case of an invalid model type, an invalid argument
+            for `weights`, or an invalid input shape when
+            weights='imagenet'.
+    """
+    global backend, layers, models, keras_utils
+    backend, layers, models, keras_utils = get_submodules_from_kwargs(kwargs)
+
+    if not (weights in {'imagenet', None} or os.path.exists(weights)):
+        raise ValueError('The `weights` argument should be either '
+                         '`None` (random initialization), `imagenet` '
+                         '(pre-training on ImageNet), '
+                         'or the path to the weights file to be loaded.')
+
+    if weights == 'imagenet' and include_top and classes != 1000:
+        raise ValueError('If using `weights` as `"imagenet"` with '
+                         '`include_top` as true, `classes` should be 1000')
+
+    # Determine proper input shape and default size.
+    # If both input_shape and input_tensor are used, they should match.
+    if input_shape is not None and input_tensor is not None:
+        try:
+            is_input_t_tensor = backend.is_keras_tensor(input_tensor)
+        except ValueError:
+            try:
+                is_input_t_tensor = backend.is_keras_tensor(
+                    keras_utils.get_source_inputs(input_tensor))
+            except ValueError:
+                raise ValueError('input_tensor: ', input_tensor,
+                                 'is not a valid type for input_tensor')
+        if is_input_t_tensor:
+            if backend.image_data_format() == 'channels_first':
+                if backend.int_shape(input_tensor)[1] != input_shape[1]:
+                    raise ValueError('input_shape: ', input_shape,
+                                     'and input_tensor: ', input_tensor,
+                                     'do not meet the same shape requirements')
+            else:
+                if backend.int_shape(input_tensor)[2] != input_shape[1]:
+                    raise ValueError('input_shape: ', input_shape,
+                                     'and input_tensor: ', input_tensor,
+                                     'do not meet the same shape requirements')
+        else:
+            raise ValueError('input_tensor specified: ', input_tensor,
+                             'is not a keras tensor')
+
+    # If input_shape is None, infer the shape from input_tensor.
+    if input_shape is None and input_tensor is not None:
+
+        try:
+            backend.is_keras_tensor(input_tensor)
+        except ValueError:
+            raise ValueError('input_tensor: ', input_tensor,
+                             'is type: ', type(input_tensor),
+                             'which is not a valid type')
+
+        if backend.is_keras_tensor(input_tensor):
+            if backend.image_data_format() == 'channels_first':
+                rows = backend.int_shape(input_tensor)[2]
+                cols = backend.int_shape(input_tensor)[3]
+                input_shape = (3, rows, cols)
+            else:
+                rows = backend.int_shape(input_tensor)[1]
+                cols = backend.int_shape(input_tensor)[2]
+                input_shape = (rows, cols, 3)
+    # If both input_shape and input_tensor are None, use a standard shape.
+    if input_shape is None and input_tensor is None:
+        input_shape = (None, None, 3)
+
+    if backend.image_data_format() == 'channels_last':
+        row_axis, col_axis = (0, 1)
+    else:
+        row_axis, col_axis = (1, 2)
+    rows = input_shape[row_axis]
+    cols = input_shape[col_axis]
+    if rows and cols and (rows < 32 or cols < 32):
+        raise ValueError('Input size must be at least 32x32; got '
+                         '`input_shape=' + str(input_shape) + '`')
+    if weights == 'imagenet':
+        if (minimalistic is False and alpha not in [0.75, 1.0]) \
+                or (minimalistic is True and alpha != 1.0):
+            raise ValueError('If imagenet weights are being loaded, '
+                             'alpha can be one of `0.75` or `1.0` for '
+                             'non-minimalistic models, or `1.0` for '
+                             'minimalistic models only.')
+
+        if rows != cols or rows != 224:
+            warnings.warn('`input_shape` is undefined or non-square, '
+                          'or `rows` is not 224.'
+                          ' Weights for input shape (224, 224) will be'
+                          ' loaded as the default.')
+
+    if input_tensor is None:
+        img_input = layers.Input(shape=input_shape)
+    else:
+        if not backend.is_keras_tensor(input_tensor):
+            img_input = layers.Input(tensor=input_tensor, shape=input_shape)
+        else:
+            img_input = input_tensor
+
+    channel_axis = 1 if backend.image_data_format() == 'channels_first' else -1
+
+    if minimalistic:
+        kernel = 3
+        activation = relu
+        se_ratio = None
+    else:
+        kernel = 5
+        activation = hard_swish
+        se_ratio = 0.25
+
+    x = layers.ZeroPadding2D(padding=correct_pad(backend, img_input, 3),
+                             name='Conv_pad')(img_input)
+    x = layers.Conv2D(16,
+                      kernel_size=3,
+                      strides=(2, 2),
+                      padding='valid',
+                      use_bias=False,
+                      name='Conv')(x)
+    x = layers.BatchNormalization(axis=channel_axis,
+                                  epsilon=1e-3,
+                                  momentum=0.999,
+                                  name='Conv/BatchNorm')(x)
+    x = layers.Activation(activation)(x)
+
+    x = stack_fn(x, kernel, activation, se_ratio)
+
+    last_conv_ch = _depth(backend.int_shape(x)[channel_axis] * 6)
+
+    # If the width multiplier is greater than 1, we
+    # increase the number of output channels.
+    if alpha > 1.0:
+        last_point_ch = _depth(last_point_ch * alpha)
+
+    x = layers.Conv2D(last_conv_ch,
+                      kernel_size=1,
+                      padding='same',
+                      use_bias=False,
+                      name='Conv_1')(x)
+    x = layers.BatchNormalization(axis=channel_axis,
+                                  epsilon=1e-3,
+                                  momentum=0.999,
+                                  name='Conv_1/BatchNorm')(x)
+    x = layers.Activation(activation)(x)
+
+    if include_top:
+        x = layers.GlobalAveragePooling2D()(x)
+        if channel_axis == 1:
+            x = layers.Reshape((last_conv_ch, 1, 1))(x)
+        else:
+            x = layers.Reshape((1, 1, last_conv_ch))(x)
+        x = layers.Conv2D(last_point_ch,
+                          kernel_size=1,
+                          padding='same',
+                          name='Conv_2')(x)
+        x = layers.Activation(activation)(x)
+        if dropout_rate > 0:
+            x = layers.Dropout(dropout_rate)(x)
+        x = layers.Conv2D(classes,
+                          kernel_size=1,
+                          padding='same',
+                          name='Logits')(x)
+        x = layers.Flatten()(x)
+        x = layers.Softmax(name='Predictions/Softmax')(x)
+    else:
+        if pooling == 'avg':
+            x = layers.GlobalAveragePooling2D(name='avg_pool')(x)
+        elif pooling == 'max':
+            x = layers.GlobalMaxPooling2D(name='max_pool')(x)
+    # Ensure that the model takes into account
+    # any potential predecessors of `input_tensor`.
+    if input_tensor is not None:
+        inputs = keras_utils.get_source_inputs(input_tensor)
+    else:
+        inputs = img_input
+
+    # Create model.
+    model = models.Model(inputs, x, name='MobilenetV3' + model_type)
+
+    # Load weights.
+    if weights == 'imagenet':
+        model_name = '{}{}_224_{}_float'.format(
+            model_type, '_minimalistic' if minimalistic else '', str(alpha))
+        if include_top:
+            file_name = 'weights_mobilenet_v3_' + model_name + '.h5'
+            file_hash = WEIGHTS_HASHES[model_name][0]
+        else:
+            file_name = 'weights_mobilenet_v3_' + model_name + '_no_top.h5'
+            file_hash = WEIGHTS_HASHES[model_name][1]
+        weights_path = keras_utils.get_file(file_name,
+                                            BASE_WEIGHT_PATH + file_name,
+                                            cache_subdir='models',
+                                            file_hash=file_hash)
+        model.load_weights(weights_path)
+    elif weights is not None:
+        model.load_weights(weights)
+
+    return model
+
+
+def MobileNetV3Small(input_shape=None,
+                     alpha=1.0,
+                     minimalistic=False,
+                     include_top=True,
+                     weights='imagenet',
+                     input_tensor=None,
+                     classes=1000,
+                     pooling=None,
+                     dropout_rate=0.2,
+                     **kwargs):
+    def stack_fn(x, kernel, activation, se_ratio):
+        def depth(d):
+            return _depth(d * alpha)
+        x = _inverted_res_block(x, 1, depth(16), 3, 2, se_ratio, relu, 0)
+        x = _inverted_res_block(x, 72. / 16, depth(24), 3, 2, None, relu, 1)
+        x = _inverted_res_block(x, 88. / 24, depth(24), 3, 1, None, relu, 2)
+        x = _inverted_res_block(x, 4, depth(40), kernel, 2, se_ratio, activation, 3)
+        x = _inverted_res_block(x, 6, depth(40), kernel, 1, se_ratio, activation, 4)
+        x = _inverted_res_block(x, 6, depth(40), kernel, 1, se_ratio, activation, 5)
+        x = _inverted_res_block(x, 3, depth(48), kernel, 1, se_ratio, activation, 6)
+        x = _inverted_res_block(x, 3, depth(48), kernel, 1, se_ratio, activation, 7)
+        x = _inverted_res_block(x, 6, depth(96), kernel, 2, se_ratio, activation, 8)
+        x = _inverted_res_block(x, 6, depth(96), kernel, 1, se_ratio, activation, 9)
+        x = _inverted_res_block(x, 6, depth(96), kernel, 1, se_ratio, activation, 10)
+        return x
+    return MobileNetV3(stack_fn,
+                       1024,
+                       input_shape,
+                       alpha,
+                       'small',
+                       minimalistic,
+                       include_top,
+                       weights,
+                       input_tensor,
+                       classes,
+                       pooling,
+                       dropout_rate,
+                       **kwargs)
+
+
+def MobileNetV3Large(input_shape=None,
+                     alpha=1.0,
+                     minimalistic=False,
+                     include_top=True,
+                     weights='imagenet',
+                     input_tensor=None,
+                     classes=1000,
+                     pooling=None,
+                     dropout_rate=0.2,
+                     **kwargs):
+    def stack_fn(x, kernel, activation, se_ratio):
+        def depth(d):
+            return _depth(d * alpha)
+        x = _inverted_res_block(x, 1, depth(16), 3, 1, None, relu, 0)
+        x = _inverted_res_block(x, 4, depth(24), 3, 2, None, relu, 1)
+        x = _inverted_res_block(x, 3, depth(24), 3, 1, None, relu, 2)
+        x = _inverted_res_block(x, 3, depth(40), kernel, 2, se_ratio, relu, 3)
+        x = _inverted_res_block(x, 3, depth(40), kernel, 1, se_ratio, relu, 4)
+        x = _inverted_res_block(x, 3, depth(40), kernel, 1, se_ratio, relu, 5)
+        x = _inverted_res_block(x, 6, depth(80), 3, 2, None, activation, 6)
+        x = _inverted_res_block(x, 2.5, depth(80), 3, 1, None, activation, 7)
+        x = _inverted_res_block(x, 2.3, depth(80), 3, 1, None, activation, 8)
+        x = _inverted_res_block(x, 2.3, depth(80), 3, 1, None, activation, 9)
+        x = _inverted_res_block(x, 6, depth(112), 3, 1, se_ratio, activation, 10)
+        x = _inverted_res_block(x, 6, depth(112), 3, 1, se_ratio, activation, 11)
+        x = _inverted_res_block(x, 6, depth(160), kernel, 2, se_ratio,
+                                activation, 12)
+        x = _inverted_res_block(x, 6, depth(160), kernel, 1, se_ratio,
+                                activation, 13)
+        x = _inverted_res_block(x, 6, depth(160), kernel, 1, se_ratio,
+                                activation, 14)
+        return x
+    return MobileNetV3(stack_fn,
+                       1280,
+                       input_shape,
+                       alpha,
+                       'large',
+                       minimalistic,
+                       include_top,
+                       weights,
+                       input_tensor,
+                       classes,
+                       pooling,
+                       dropout_rate,
+                       **kwargs)
+
+
+setattr(MobileNetV3Small, '__doc__', MobileNetV3.__doc__)
+setattr(MobileNetV3Large, '__doc__', MobileNetV3.__doc__)
diff --git a/tests/applications_test.py b/tests/applications_test.py
index 1907255..cbbffc4 100644
--- a/tests/applications_test.py
+++ b/tests/applications_test.py
@@ -40,7 +40,8 @@ def wrapper(*args, **kwargs):
 for (name, module) in [('resnet', keras_applications.resnet),
                        ('resnet_v2', keras_applications.resnet_v2),
                        ('resnext', keras_applications.resnext),
-                       ('efficientnet', keras_applications.efficientnet)]:
+                       ('efficientnet', keras_applications.efficientnet),
+                       ('mobilenet_v3', keras_applications.mobilenet_v3)]:
     module.decode_predictions = keras_modules_injection(module.decode_predictions)
     module.preprocess_input = keras_modules_injection(module.preprocess_input)
     for app in dir(module):
@@ -58,7 +59,11 @@ def wrapper(*args, **kwargs):
 RESNEXT_LIST = [keras_applications.resnext.ResNeXt50,
                 keras_applications.resnext.ResNeXt101]
 MOBILENET_LIST = [(mobilenet.MobileNet, mobilenet, 1024),
-                  (mobilenet_v2.MobileNetV2, mobilenet_v2, 1280)]
+                  (mobilenet_v2.MobileNetV2, mobilenet_v2, 1280),
+                  (keras_applications.mobilenet_v3.MobileNetV3Small,
+                   keras_applications.mobilenet_v3, 576),
+                  (keras_applications.mobilenet_v3.MobileNetV3Large,
+                   keras_applications.mobilenet_v3, 960)]
 DENSENET_LIST = [(densenet.DenseNet121, 1024),
                  (densenet.DenseNet169, 1664),
                  (densenet.DenseNet201, 1920)]
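
For reviewers who want to smoke-test the new entry points, here is a minimal usage sketch. It is not part of the diff; it assumes `keras` and `numpy` are installed and that this branch of `keras_applications` is importable, and the random array is only a stand-in for a real image.

```python
import keras
import numpy as np

from keras_applications.mobilenet_v3 import MobileNetV3Large, preprocess_input

# keras_applications models receive the Keras submodules explicitly.
model = MobileNetV3Large(weights='imagenet',
                         backend=keras.backend,
                         layers=keras.layers,
                         models=keras.models,
                         utils=keras.utils)

# A random 224x224 RGB batch (channels_last) as a stand-in for a real image.
img = np.random.uniform(0, 255, (1, 224, 224, 3))
preds = model.predict(preprocess_input(img, backend=keras.backend))
print(preds.shape)  # (1, 1000)
```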
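For intuition about the helpers the new file is built from, here is a pure-NumPy sketch of `hard_sigmoid`, `hard_swish`, and `_depth`. The diff implements the activations with Keras layers rather than NumPy; the arithmetic below is the same, and the function names are local stand-ins.

```python
import numpy as np

def hard_sigmoid(x):
    # ReLU6(x + 3) / 6: a piecewise-linear approximation of the sigmoid.
    return np.clip(x + 3., 0., 6.) / 6.

def hard_swish(x):
    # x * hard_sigmoid(x): a cheap approximation of the swish activation.
    return x * hard_sigmoid(x)

def depth(v, divisor=8, min_value=None):
    # Round v to the nearest multiple of `divisor` (at least `min_value`),
    # never rounding down by more than 10%.
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v

print(hard_swish(np.array([-4., 0., 4.])))  # [-0.  0.  4.]
print(depth(16 * 0.75), depth(40 * 0.75))   # 16 32: alpha=0.75 widths snap to multiples of 8
```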