I updated the CBAM-keras code for TF 2.0 #5

Open. Wants to merge 4 commits into base: master.
30 changes: 15 additions & 15 deletions main.py
@@ -4,17 +4,17 @@
"""

from __future__ import print_function
import keras
from keras.layers import Dense, Conv2D, BatchNormalization, Activation
from keras.layers import AveragePooling2D, Input, Flatten
from keras.optimizers import Adam
from keras.callbacks import ModelCheckpoint, LearningRateScheduler
from keras.callbacks import ReduceLROnPlateau
from keras.preprocessing.image import ImageDataGenerator
from keras.regularizers import l2
from keras import backend as K
from keras.models import Model
from keras.datasets import cifar10
import tensorflow.keras
from tensorflow.keras.layers import Dense, Conv2D, BatchNormalization, Activation
from tensorflow.keras.layers import AveragePooling2D, Input, Flatten
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint, LearningRateScheduler
from tensorflow.keras.callbacks import ReduceLROnPlateau
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.regularizers import l2
from tensorflow.keras import backend as K
from tensorflow.keras.models import Model
from tensorflow.keras.datasets import cifar10
from models import resnext, resnet_v1, resnet_v2, mobilenets, inception_v3, inception_resnet_v2, densenet
from utils import lr_schedule
import numpy as np
@@ -53,17 +53,17 @@
print('y_train shape:', y_train.shape)

# Convert class vectors to binary class matrices.
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)
y_train = tensorflow.keras.utils.to_categorical(y_train, num_classes)
y_test = tensorflow.keras.utils.to_categorical(y_test, num_classes)

depth = 20 # For ResNet, specify the depth (e.g. ResNet50: depth=50)
model = resnet_v1.resnet_v1(input_shape=input_shape, depth=depth, attention_module=attention_module)
# model = resnet_v1.resnet_v1(input_shape=input_shape, depth=depth, attention_module=attention_module)
# model = resnet_v2.resnet_v2(input_shape=input_shape, depth=depth, attention_module=attention_module)
# model = resnext.ResNext(input_shape=input_shape, classes=num_classes, attention_module=attention_module)
# model = mobilenets.MobileNet(input_shape=input_shape, classes=num_classes, attention_module=attention_module)
# model = inception_v3.InceptionV3(input_shape=input_shape, classes=num_classes, attention_module=attention_module)
# model = inception_resnet_v2.InceptionResNetV2(input_shape=input_shape, classes=num_classes, attention_module=attention_module)
# model = densenet.DenseNet(input_shape=input_shape, classes=num_classes, attention_module=attention_module)
model = densenet.DenseNet(input_shape=input_shape, classes=num_classes, attention_module=attention_module)

model.compile(loss='categorical_crossentropy',
optimizer=Adam(lr=lr_schedule(0)),
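Reviewer note: besides swapping every standalone keras.* import for its tensorflow.keras.* counterpart, this hunk also switches the default model from resnet_v1 to densenet. A minimal smoke test of the migrated import paths is sketched below (not part of the PR; newer TF 2.x releases prefer learning_rate= over the lr= argument kept in this diff, though lr= is still accepted in many versions).

import tensorflow as tf
from tensorflow.keras.layers import Input, Flatten, Dense
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam

print(tf.__version__)  # expect a 2.x release

# Tiny stand-in model: confirms the migrated import paths resolve and compile.
inputs = Input(shape=(32, 32, 3))
outputs = Dense(10, activation='softmax')(Flatten()(inputs))
model = Model(inputs, outputs)
model.compile(loss='categorical_crossentropy',
              optimizer=Adam(learning_rate=1e-3),
              metrics=['accuracy'])
model.summary()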
Binary file removed models/__pycache__/attention_module.cpython-36.pyc
Binary file removed models/__pycache__/densenet.cpython-36.pyc
Binary file removed (name not shown in the rendered diff)
Binary file removed models/__pycache__/inception_v3.cpython-36.pyc
Binary file removed models/__pycache__/mobilenets.cpython-36.pyc
Binary file removed models/__pycache__/resnet_v1.cpython-36.pyc
Binary file removed models/__pycache__/resnet_v2.cpython-36.pyc
Binary file removed models/__pycache__/resnext.cpython-36.pyc
Binary file removed models/__pycache__/se.cpython-36.pyc
Binary file removed models/__pycache__/se_resnext.cpython-36.pyc
40 changes: 20 additions & 20 deletions models/attention_module.py
@@ -1,6 +1,6 @@
from keras.layers import GlobalAveragePooling2D, GlobalMaxPooling2D, Reshape, Dense, multiply, Permute, Concatenate, Conv2D, Add, Activation, Lambda
from keras import backend as K
from keras.activations import sigmoid
from tensorflow.keras.layers import GlobalAveragePooling2D, GlobalMaxPooling2D, Reshape, Dense, multiply, Permute, Concatenate, Conv2D, Add, Activation, Lambda
from tensorflow.keras import backend as K
from tensorflow.keras.activations import sigmoid

def attach_attention_module(net, attention_module):
if attention_module == 'se_block': # SE_block
@@ -18,23 +18,23 @@ def se_block(input_feature, ratio=8):
"""

channel_axis = 1 if K.image_data_format() == "channels_first" else -1
channel = input_feature._keras_shape[channel_axis]
channel = input_feature.shape[channel_axis]

se_feature = GlobalAveragePooling2D()(input_feature)
se_feature = Reshape((1, 1, channel))(se_feature)
assert se_feature._keras_shape[1:] == (1,1,channel)
assert se_feature.shape[1:] == (1,1,channel)
se_feature = Dense(channel // ratio,
activation='relu',
kernel_initializer='he_normal',
use_bias=True,
bias_initializer='zeros')(se_feature)
assert se_feature._keras_shape[1:] == (1,1,channel//ratio)
assert se_feature.shape[1:] == (1,1,channel//ratio)
se_feature = Dense(channel,
activation='sigmoid',
kernel_initializer='he_normal',
use_bias=True,
bias_initializer='zeros')(se_feature)
assert se_feature._keras_shape[1:] == (1,1,channel)
assert se_feature.shape[1:] == (1,1,channel)
if K.image_data_format() == 'channels_first':
se_feature = Permute((3, 1, 2))(se_feature)

@@ -53,7 +53,7 @@ def cbam_block(cbam_feature, ratio=8):
def channel_attention(input_feature, ratio=8):

channel_axis = 1 if K.image_data_format() == "channels_first" else -1
channel = input_feature._keras_shape[channel_axis]
channel = input_feature.shape[channel_axis]

shared_layer_one = Dense(channel//ratio,
activation='relu',
@@ -67,19 +67,19 @@ def channel_attention(input_feature, ratio=8):

avg_pool = GlobalAveragePooling2D()(input_feature)
avg_pool = Reshape((1,1,channel))(avg_pool)
assert avg_pool._keras_shape[1:] == (1,1,channel)
assert avg_pool.shape[1:] == (1,1,channel)
avg_pool = shared_layer_one(avg_pool)
assert avg_pool._keras_shape[1:] == (1,1,channel//ratio)
assert avg_pool.shape[1:] == (1,1,channel//ratio)
avg_pool = shared_layer_two(avg_pool)
assert avg_pool._keras_shape[1:] == (1,1,channel)
assert avg_pool.shape[1:] == (1,1,channel)

max_pool = GlobalMaxPooling2D()(input_feature)
max_pool = Reshape((1,1,channel))(max_pool)
assert max_pool._keras_shape[1:] == (1,1,channel)
assert max_pool.shape[1:] == (1,1,channel)
max_pool = shared_layer_one(max_pool)
assert max_pool._keras_shape[1:] == (1,1,channel//ratio)
assert max_pool.shape[1:] == (1,1,channel//ratio)
max_pool = shared_layer_two(max_pool)
assert max_pool._keras_shape[1:] == (1,1,channel)
assert max_pool.shape[1:] == (1,1,channel)

cbam_feature = Add()([avg_pool,max_pool])
cbam_feature = Activation('sigmoid')(cbam_feature)
@@ -93,26 +93,26 @@ def spatial_attention(input_feature):
kernel_size = 7

if K.image_data_format() == "channels_first":
channel = input_feature._keras_shape[1]
channel = input_feature.shape[1]
cbam_feature = Permute((2,3,1))(input_feature)
else:
channel = input_feature._keras_shape[-1]
channel = input_feature.shape[-1]
cbam_feature = input_feature

avg_pool = Lambda(lambda x: K.mean(x, axis=3, keepdims=True))(cbam_feature)
assert avg_pool._keras_shape[-1] == 1
assert avg_pool.shape[-1] == 1
max_pool = Lambda(lambda x: K.max(x, axis=3, keepdims=True))(cbam_feature)
assert max_pool._keras_shape[-1] == 1
assert max_pool.shape[-1] == 1
concat = Concatenate(axis=3)([avg_pool, max_pool])
assert concat._keras_shape[-1] == 2
assert concat.shape[-1] == 2
cbam_feature = Conv2D(filters = 1,
kernel_size=kernel_size,
strides=1,
padding='same',
activation='sigmoid',
kernel_initializer='he_normal',
use_bias=False)(concat)
assert cbam_feature._keras_shape[-1] == 1
assert cbam_feature.shape[-1] == 1

if K.image_data_format() == "channels_first":
cbam_feature = Permute((3, 1, 2))(cbam_feature)
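Reviewer note: replacing the private ._keras_shape attribute with the public .shape property is the right call for tf.keras. In TF 2.x, .shape on a symbolic Keras tensor returns a TensorShape, which compares equal to a plain tuple, so the existing assertions keep working unchanged. A small illustration (the 32x32x64 input size is an assumption, not taken from the PR):

from tensorflow.keras.layers import Input, GlobalAveragePooling2D, Reshape, Dense

x = Input(shape=(32, 32, 64))
channel = x.shape[-1]                   # 64, usable directly as a layer size
se = GlobalAveragePooling2D()(x)
se = Reshape((1, 1, channel))(se)
assert se.shape[1:] == (1, 1, channel)  # TensorShape == tuple holds
se = Dense(channel // 8, activation='relu')(se)
assert se.shape[1:] == (1, 1, channel // 8)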
27 changes: 12 additions & 15 deletions models/densenet.py
@@ -11,21 +11,18 @@

import warnings

from keras.models import Model
from keras.layers.core import Dense, Dropout, Activation, Reshape
from keras.layers.convolutional import Conv2D, Conv2DTranspose, UpSampling2D
from keras.layers.pooling import AveragePooling2D, MaxPooling2D
from keras.layers.pooling import GlobalAveragePooling2D
from keras.layers import Input
from keras.layers.merge import concatenate
from keras.layers.normalization import BatchNormalization
from keras.regularizers import l2
from keras.utils.layer_utils import convert_all_kernels_in_model, convert_dense_weights_data_format
from keras.utils.data_utils import get_file
from keras.engine.topology import get_source_inputs
from keras.applications.imagenet_utils import _obtain_input_shape
from keras.applications.imagenet_utils import decode_predictions
import keras.backend as K
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, Dropout, Activation, Reshape
from tensorflow.keras.layers import Conv2D, Conv2DTranspose, UpSampling2D
from tensorflow.keras.layers import AveragePooling2D, MaxPooling2D
from tensorflow.keras.layers import GlobalAveragePooling2D
from tensorflow.keras.layers import Input
from tensorflow.keras.layers import concatenate
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.regularizers import l2
from tensorflow.keras.utils import get_source_inputs
from keras_applications.imagenet_utils import _obtain_input_shape
import tensorflow.keras.backend as K

from models.attention_module import attach_attention_module

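Reviewer note: densenet.py now imports _obtain_input_shape from the separately installed keras_applications package, because tensorflow.keras.applications does not expose that private helper. Its argument names have also shifted across releases (for example include_top was renamed to require_flatten), so pinning keras_applications or replacing the call with a small local check is worth considering. A hedged sketch of such a replacement, assuming CIFAR-10-sized RGB inputs:

from tensorflow.keras import backend as K

def obtain_input_shape(input_shape, default_size=32, min_size=32, channels=3):
    """Return a validated input shape, defaulting to a square RGB image."""
    if input_shape is None:
        if K.image_data_format() == 'channels_first':
            return (channels, default_size, default_size)
        return (default_size, default_size, channels)
    spatial = input_shape[1:] if K.image_data_format() == 'channels_first' else input_shape[:2]
    if any(s is not None and s < min_size for s in spatial):
        raise ValueError('Input spatial size must be at least %dx%d.' % (min_size, min_size))
    return input_shape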
35 changes: 17 additions & 18 deletions models/inception_resnet_v2.py
@@ -21,24 +21,23 @@

import warnings

from keras.models import Model
from keras.layers import Activation
from keras.layers import AveragePooling2D
from keras.layers import BatchNormalization
from keras.layers import Concatenate
from keras.layers import Conv2D
from keras.layers import Dense
from keras.layers import GlobalAveragePooling2D
from keras.layers import GlobalMaxPooling2D
from keras.layers import Input
from keras.layers import Lambda
from keras.layers import MaxPooling2D
from keras.utils.data_utils import get_file
from keras.engine.topology import get_source_inputs
from keras.applications import imagenet_utils
from keras.applications.imagenet_utils import _obtain_input_shape
from keras.applications.imagenet_utils import decode_predictions
from keras import backend as K
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Activation
from tensorflow.keras.layers import AveragePooling2D
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.layers import Concatenate
from tensorflow.keras.layers import Conv2D
from tensorflow.keras.layers import Dense
from tensorflow.keras.layers import GlobalAveragePooling2D
from tensorflow.keras.layers import GlobalMaxPooling2D
from tensorflow.keras.layers import Input
from tensorflow.keras.layers import Lambda
from tensorflow.keras.layers import MaxPooling2D
from tensorflow.keras.utils import get_file
from tensorflow.keras.utils import get_source_inputs
from tensorflow.keras.applications import imagenet_utils
from keras_applications.imagenet_utils import _obtain_input_shape
from tensorflow.keras import backend as K

from models.attention_module import attach_attention_module

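Reviewer note: get_source_inputs only became a public tensorflow.keras.utils symbol in later TF 2.x releases, so the import above may fail on early 2.x versions. A hedged fallback (the private path is an assumption about where a given TF build keeps the helper):

try:
    from tensorflow.keras.utils import get_source_inputs
except ImportError:
    # Older TF 2.x builds only ship the helper under a private path.
    from tensorflow.python.keras.utils.layer_utils import get_source_inputs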
33 changes: 16 additions & 17 deletions models/inception_v3.py
@@ -18,23 +18,22 @@

import warnings

from keras.models import Model
from keras import layers
from keras.layers import Activation
from keras.layers import Dense
from keras.layers import Reshape
from keras.layers import Input
from keras.layers import BatchNormalization
from keras.layers import Conv2D
from keras.layers import MaxPooling2D
from keras.layers import AveragePooling2D
from keras.layers import GlobalAveragePooling2D
from keras.layers import GlobalMaxPooling2D
from keras.engine.topology import get_source_inputs
from keras.utils.data_utils import get_file
from keras import backend as K
from keras.applications.imagenet_utils import decode_predictions
from keras.applications.imagenet_utils import _obtain_input_shape
from tensorflow.keras.models import Model
from tensorflow.keras import layers
from tensorflow.keras.layers import Activation
from tensorflow.keras.layers import Dense
from tensorflow.keras.layers import Reshape
from tensorflow.keras.layers import Input
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.layers import Conv2D
from tensorflow.keras.layers import MaxPooling2D
from tensorflow.keras.layers import AveragePooling2D
from tensorflow.keras.layers import GlobalAveragePooling2D
from tensorflow.keras.layers import GlobalMaxPooling2D
from tensorflow.keras.utils import get_source_inputs
from tensorflow.keras.utils import get_file
from tensorflow.keras import backend as K
from keras_applications.imagenet_utils import _obtain_input_shape

from models.attention_module import attach_attention_module

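Reviewer note: get_file moves cleanly to tensorflow.keras.utils for the pretrained-weight download. A minimal usage sketch (the file name and URL below are placeholders, not the repository's actual weight locations):

from tensorflow.keras.utils import get_file

weights_path = get_file(
    'inception_v3_weights_notop.h5',                              # placeholder name
    origin='https://example.com/inception_v3_weights_notop.h5',   # placeholder URL
    cache_subdir='models')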
39 changes: 19 additions & 20 deletions models/mobilenets.py
@@ -11,26 +11,25 @@

import warnings

from keras.models import Model
from keras.layers import Input
from keras.layers import Activation
from keras.layers import Dropout
from keras.layers import Reshape
from keras.layers import BatchNormalization
from keras.layers import GlobalAveragePooling2D
from keras.layers import GlobalMaxPooling2D
from keras.layers import Conv2D
from keras import initializers
from keras import regularizers
from keras import constraints
from keras.utils import conv_utils
from keras.utils.data_utils import get_file
from keras.engine.topology import get_source_inputs
from keras.engine import InputSpec
from keras.applications import imagenet_utils
from keras.applications.imagenet_utils import _obtain_input_shape
from keras.applications.imagenet_utils import decode_predictions
from keras import backend as K
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input
from tensorflow.keras.layers import Activation
from tensorflow.keras.layers import Dropout
from tensorflow.keras.layers import Reshape
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.layers import GlobalAveragePooling2D
from tensorflow.keras.layers import GlobalMaxPooling2D
from tensorflow.keras.layers import Conv2D
from tensorflow.keras import initializers
from tensorflow.keras import regularizers
from tensorflow.keras import constraints
import tensorflow.keras.utils as conv_utils
from tensorflow.keras.utils import get_file
from tensorflow.keras.utils import get_source_inputs
from tensorflow.keras.layers import InputSpec
from tensorflow.keras.applications import imagenet_utils
from keras_applications.imagenet_utils import _obtain_input_shape
from tensorflow.keras import backend as K

from models.attention_module import attach_attention_module

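Reviewer note: import tensorflow.keras.utils as conv_utils is not a drop-in replacement for the old keras.utils.conv_utils module, because tensorflow.keras.utils does not re-export helpers such as normalize_tuple or conv_output_length. If mobilenets.py actually calls those helpers, a fallback along these lines may be needed (the private TF path is an assumption about the installed version):

try:
    from keras.utils import conv_utils                     # standalone Keras, if present
except ImportError:
    from tensorflow.python.keras.utils import conv_utils   # private tf.keras location

kernel = conv_utils.normalize_tuple(3, 2, 'kernel_size')   # -> (3, 3)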
22 changes: 11 additions & 11 deletions models/resnet_v1.py
@@ -7,16 +7,16 @@
"""

from __future__ import print_function
import keras
from keras.layers import Dense, Conv2D, BatchNormalization, Activation
from keras.layers import AveragePooling2D, Input, Flatten
from keras.optimizers import Adam
from keras.callbacks import ModelCheckpoint, LearningRateScheduler
from keras.callbacks import ReduceLROnPlateau
from keras.preprocessing.image import ImageDataGenerator
from keras.regularizers import l2
from keras import backend as K
from keras.models import Model
import tensorflow.keras
from tensorflow.keras.layers import Dense, Conv2D, BatchNormalization, Activation
from tensorflow.keras.layers import AveragePooling2D, Input, Flatten
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint, LearningRateScheduler
from tensorflow.keras.callbacks import ReduceLROnPlateau
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.regularizers import l2
from tensorflow.keras import backend as K
from tensorflow.keras.models import Model
from models.attention_module import attach_attention_module

def resnet_layer(inputs,
@@ -124,7 +124,7 @@ def resnet_v1(input_shape, depth, num_classes=10, attention_module=None):
# attention_module
if attention_module is not None:
y = attach_attention_module(y, attention_module)
x = keras.layers.add([x, y])
x = tensorflow.keras.layers.add([x, y])
x = Activation('relu')(x)
num_filters *= 2

22 changes: 11 additions & 11 deletions models/resnet_v2.py
@@ -7,16 +7,16 @@
"""

from __future__ import print_function
import keras
from keras.layers import Dense, Conv2D, BatchNormalization, Activation
from keras.layers import AveragePooling2D, Input, Flatten
from keras.optimizers import Adam
from keras.callbacks import ModelCheckpoint, LearningRateScheduler
from keras.callbacks import ReduceLROnPlateau
from keras.preprocessing.image import ImageDataGenerator
from keras.regularizers import l2
from keras import backend as K
from keras.models import Model
import tensorflow.keras
from tensorflow.keras.layers import Dense, Conv2D, BatchNormalization, Activation
from tensorflow.keras.layers import AveragePooling2D, Input, Flatten
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint, LearningRateScheduler
from tensorflow.keras.callbacks import ReduceLROnPlateau
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.regularizers import l2
from tensorflow.keras import backend as K
from tensorflow.keras.models import Model
from models.attention_module import attach_attention_module

def resnet_layer(inputs,
@@ -144,7 +144,7 @@ def resnet_v2(input_shape, depth, num_classes=10, attention_module=None):
if attention_module is not None:
y = attach_attention_module(y, attention_module)

x = keras.layers.add([x, y])
x = tensorflow.keras.layers.add([x, y])

num_filters_in = num_filters_out

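Reviewer note: in both resnet_v1.py and resnet_v2.py the residual merge now goes through tensorflow.keras.layers.add. An equivalent, slightly tidier form is to import the Add layer once; a sketch, where x and y stand for the branch tensors the surrounding code already builds:

from tensorflow.keras.layers import Add

def residual_merge(x, y):
    # Same result as tensorflow.keras.layers.add([x, y]).
    return Add()([x, y])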