Skip to content

Commit

Permalink
Merge pull request #18 from philipperemy/vgg16
Browse files Browse the repository at this point in the history
update

Former-commit-id: cf4a055
  • Loading branch information
Philippe Rémy authored Feb 4, 2019
2 parents 3ad1312 + d18df2f commit 2c33f19
Show file tree
Hide file tree
Showing 10 changed files with 76 additions and 48 deletions.
23 changes: 6 additions & 17 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ pip install keract
*You have just found a (easy) way to get the activations (outputs) and gradients for each layer of your Keras model (LSTM, conv nets...).*

<p align="center">
<img src="assets/1.png">
<img src="assets/intro.png">
</p>

## API
Expand Down Expand Up @@ -101,28 +101,17 @@ dense_2/Softmax:0
(128, 10)
```

We can even visualise some of them.
We can also visualise the activations. Here's another example using VGG16:

<p align="center">
<img src="assets/0.png" width="50">
<br><i>A random seven from MNIST</i>
<img src="assets/cat.jpg">
<br><i>A cat.</i>
</p>


<p align="center">
<img src="assets/1.png">
<br><i>Activation map of CONV1 of LeNet</i>
</p>

<p align="center">
<img src="assets/2.png" width="200">
<br><i>Activation map of FC1 of LeNet</i>
</p>


<p align="center">
<img src="assets/3.png" width="300">
<br><i>Activation map of Softmax of LeNet. <b>Yes it's a seven!</b></i>
<img src="assets/cat_activations.png" width="600">
<br><i>Outputs of the first conv of VGG16.</i>
</p>

### Repo views (since 2018/10/31)
Expand Down
Binary file removed assets/0.png
Binary file not shown.
Binary file removed assets/1.png
Binary file not shown.
Binary file removed assets/2.png
Binary file not shown.
Binary file removed assets/3.png
Binary file not shown.
Binary file added assets/cat.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added assets/cat_activations.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added assets/intro.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
51 changes: 51 additions & 0 deletions examples/vgg16.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from keras.applications.vgg16 import VGG16
from keras.applications.vgg16 import decode_predictions
from keras.applications.vgg16 import preprocess_input
from keras.preprocessing.image import img_to_array
from keras.preprocessing.image import load_img

model = VGG16()

from PIL import Image
import requests
from io import BytesIO

url = 'https://upload.wikimedia.org/wikipedia/commons/thumb/1/14/Gatto_europeo4.jpg/250px-Gatto_europeo4.jpg'
response = requests.get(url)
image = Image.open(BytesIO(response.content))
image = image.crop((0, 0, 244, 244))
image.save('cat.jpg')

image = load_img('cat.jpg', target_size=(224, 224))
image = img_to_array(image)
image = image.reshape((1, image.shape[0], image.shape[1], image.shape[2]))
image = preprocess_input(image)
yhat = model.predict(image)
label = decode_predictions(yhat)
label = label[0][0]
print('{} ({})'.format(label[1], label[2] * 100))

import keract

model.compile(optimizer='adam',
loss='categorical_crossentropy',
metrics=['accuracy'])
activations = keract.get_activations(model, image)
first = activations.get('block1_conv1/Relu:0')
keract.display_activations(activations)
#
# import matplotlib.pyplot as plt
#
# fig = plt.figure(figsize=(12, 12))
#
# rows = 8
# columns = 8
#
# first = activations.get('block1_conv1/Relu:0')
#
# for i in range(1, columns * rows + 1):
# img = first[0, :, :, i - 1]
# fig.add_subplot(rows, columns, i)
# plt.imshow(img)
# plt.axis('off')
# plt.show()
50 changes: 19 additions & 31 deletions keract/keract.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,36 +44,24 @@ def get_activations(model, x, layer_name=None):


def display_activations(activations):
import numpy as np
import matplotlib.pyplot as plt
"""
(1, 26, 26, 32)
(1, 24, 24, 64)
(1, 12, 12, 64)
(1, 12, 12, 64)
(1, 9216)
(1, 128)
(1, 128)
(1, 10)
"""
for name, activation_map in activations.items():
assert activation_map.shape[0] == 1, 'One image at a time to visualize.'
print('Displaying activation map [{}]'.format(name))
shape = activation_map.shape
if len(shape) == 4:
activations = np.hstack(np.transpose(activation_map[0], (2, 0, 1)))
elif len(shape) == 2:
# try to make it square as much as possible. we can skip some activations.
activations = activation_map[0]
num_activations = len(activations)
if num_activations > 1024: # too hard to display it on the screen.
square_param = int(np.floor(np.sqrt(num_activations)))
activations = activations[0: square_param * square_param]
activations = np.reshape(activations, (square_param, square_param))
else:
activations = np.expand_dims(activations, axis=0)
else:
raise Exception('len(shape) = 3 has not been implemented.')
plt.title(name)
plt.imshow(activations, interpolation='None', cmap='jet')
max_rows = 8
max_columns = 8
for layer_name, first in activations.items():
print(layer_name, first.shape, end=' ')
if first.shape[0] != 1:
print('-> Skipped. First dimension is not 1.')
continue
if len(first.shape) <= 2:
print('-> Skipped. 2D Activations.')
continue
print('')
fig = plt.figure(figsize=(12, 12))
plt.axis('off')
plt.title(layer_name)
for i in range(1, min(max_columns * max_rows + 1, first.shape[-1] + 1)):
img = first[0, :, :, i - 1]
fig.add_subplot(max_rows, max_columns, i)
plt.imshow(img)
plt.axis('off')
plt.show()

0 comments on commit 2c33f19

Please sign in to comment.