This repository has been archived by the owner on Nov 3, 2022. It is now read-only.

Commit

Fix the bottleneck of ResNet50 reported in #5
taehoonlee committed Jun 4, 2018
1 parent 4be176a commit 4b18156
Showing 1 changed file with 1 addition and 2 deletions.
3 changes: 1 addition & 2 deletions keras_applications/resnet50.py

@@ -238,9 +238,8 @@ def ResNet50(include_top=True,
     x = identity_block(x, 3, [512, 512, 2048], stage=5, block='b')
     x = identity_block(x, 3, [512, 512, 2048], stage=5, block='c')
 
-    x = layers.AveragePooling2D((7, 7), name='avg_pool')(x)
-
     if include_top:
+        x = layers.AveragePooling2D((7, 7), name='avg_pool')(x)
         x = layers.Flatten()(x)
         x = layers.Dense(classes, activation='softmax', name='fc1000')(x)
     else:

4 comments on commit 4b18156

@ahundt commented on 4b18156 on Jun 4, 2018

This will be an invisible breaking change for quite a few user models, including some of mine, which will suddenly receive the output of the last identity block without any warning or error... Are you sure about just moving it and committing it to master?

Perhaps consider printing a warning about this change, for a few versions, in the else branch of if include_top:?
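
A minimal sketch of what such a notice could look like (the helper name and wording are hypothetical, not what the library actually ships):

import warnings

def warn_headless_output_changed(include_top):
    # Hypothetical helper: emit a notice when include_top=False so that
    # existing users see that 'avg_pool' is no longer applied.
    if not include_top:
        warnings.warn("ResNet50(include_top=False) no longer ends with the "
                      "7x7 'avg_pool' layer; append AveragePooling2D((7, 7)) "
                      "if you relied on the previous (1, 1, 2048) output.",
                      UserWarning)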

@taehoonlee (Contributor, Author) commented

Thank you for the comment, @ahundt. I think this commit is better described as "a correction" than "a breaking change". When using app(weights='imagenet', include_top=False), users expect 16x or 32x down-sampled feature maps, and all the apps except ResNet50 correctly return a 32x map (7x7 for a 224x224 input). The current ResNet50 returns a 224x down-sampled map (1x1), which is clearly anomalous. In any case, I agree that printing a warning, as you suggested, is a good idea.

Please note that the results of model = ResNet50(weights='imagenet', include_top=False, input_shape=(224, 224, 3)); model.summary() are:

Before this commit:

add_16 (Add)                    (None, 7, 7, 2048)   0           bn5c_branch2c[0][0]              
                                                                 activation_46[0][0]              
__________________________________________________________________________________________________
activation_49 (Activation)      (None, 7, 7, 2048)   0           add_16[0][0]                     
__________________________________________________________________________________________________
avg_pool (AveragePooling2D)     (None, 1, 1, 2048)   0           activation_49[0][0]              
==================================================================================================

After this commit:

add_16 (Add)                    (None, 7, 7, 2048)   0           bn5c_branch2c[0][0]              
                                                                 activation_46[0][0]              
__________________________________________________________________________________________________
activation_49 (Activation)      (None, 7, 7, 2048)   0           add_16[0][0]                     
==================================================================================================
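
The comparison can be reproduced with a short check (a sketch; it assumes the keras package is installed and the ImageNet weights can be downloaded):

from keras.applications.resnet50 import ResNet50

# Build the headless model and inspect its output shape.
model = ResNet50(weights='imagenet', include_top=False,
                 input_shape=(224, 224, 3))
print(model.output_shape)  # (None, 1, 1, 2048) before this commit,
                           # (None, 7, 7, 2048) after it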

@akbargumbira commented

I don't quite get why users would expect 16x or 32x downsampling. The architecture is the original author's design, and the downsampling factor could be anything. What would happen when a new competitive architecture comes along that doesn't follow that number?

I do think that breaking changes should at least come with a warning or be mentioned in the release notes.

@taehoonlee (Contributor, Author) commented

Thank you for the comments, @akbargumbira.

  • The argument include_top=False is designed for feature extraction, as described in the official docs. The features do not always have to be 16x (14x14) or 32x (7x7) down-sampled maps, but a 224x down-sampled map (1x1) is clearly anomalous. A model that relied on the old output can restore it explicitly, as sketched below.
  • Please see the warning message.
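
For reference, a minimal sketch (assuming Keras 2.x import paths) of how a downstream model that depended on the old (1, 1, 2048) output could restore it explicitly:

from keras import layers, models
from keras.applications.resnet50 import ResNet50

# The headless backbone now ends at the 7x7 activation maps.
base = ResNet50(weights='imagenet', include_top=False,
                input_shape=(224, 224, 3))

# Re-apply the pooling layer that this commit moved inside include_top.
x = layers.AveragePooling2D((7, 7), name='avg_pool')(base.output)
features = models.Model(base.input, x)
print(features.output_shape)  # (None, 1, 1, 2048), as before the change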
