forked from willard-yuan/flask-keras-cnn-image-retrieval
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathextract_cnn_vgg16_keras.py
33 lines (29 loc) · 1.17 KB
/
extract_cnn_vgg16_keras.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
# -*- coding: utf-8 -*-
# Author: yongyuan.name
import numpy as np
from numpy import linalg as LA
from keras.applications.vgg16 import VGG16
from keras.preprocessing import image
from keras.applications.vgg16 import preprocess_input
class VGGNet:
    """Image feature extractor built on a headless, globally pooled VGG16.

    The network is created once in ``__init__`` and reused by
    ``extract_feat`` to turn an image file into an L2-normalized
    feature vector.
    """

    def __init__(self) -> None:
        """Build the VGG16 model and run one dummy prediction through it."""
        # weights: 'imagenet'
        # pooling: 'max' or 'avg'
        # input_shape: 3-channel image shape; the original author's note says
        # both spatial sides should be >= 48 (NOTE(review): the exact minimum
        # depends on the installed Keras version - confirm).
        self.input_shape = (224, 224, 3)
        self.weight = 'imagenet'
        self.pooling = 'max'
        self.model = VGG16(
            weights=self.weight,
            input_shape=(
                self.input_shape[0],
                self.input_shape[1],
                self.input_shape[2],
            ),
            pooling=self.pooling,
            include_top=False,
        )
        # Dummy forward pass on an all-zero batch of one image. The batch
        # shape is derived from self.input_shape so it cannot drift out of
        # sync with the model if the configured shape is ever changed.
        # NOTE(review): presumably this exists to pay first-call setup cost
        # at construction time - confirm with the callers.
        self.model.predict(np.zeros((1,) + tuple(self.input_shape)))

    def extract_feat(self, img_path) -> np.ndarray:
        """Use the VGG16 model to extract features from an image file.

        Args:
            img_path: Location of the image to load; it is handed straight
                to ``image.load_img``.

        Returns:
            The feature vector of the image divided by its Euclidean norm.
            If the norm is zero the (all-zero) vector is returned as is,
            rather than an array of NaNs.
        """
        target_size = (self.input_shape[0], self.input_shape[1])
        img = image.load_img(img_path, target_size=target_size)
        # The model consumes batches, so add a leading batch axis of size 1.
        batch = np.expand_dims(image.img_to_array(img), axis=0)
        batch = preprocess_input(batch)
        # Only one image was submitted; keep the first (only) result row.
        feat = self.model.predict(batch)[0]
        norm = LA.norm(feat)
        if norm == 0:
            # Dividing by zero would yield NaNs and poison any later
            # similarity computation; a zero vector stays a zero vector.
            return feat
        return feat / norm