-
Notifications
You must be signed in to change notification settings - Fork 23
/
Copy pathsiamese.py
105 lines (88 loc) · 4.4 KB
/
siamese.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import os
import matplotlib.pyplot as plt
import numpy as np
from keras import backend as K
from PIL import Image
from nets.siamese import siamese
from utils.utils import (cvtColor, letterbox_image, preprocess_input,
show_config)
#---------------------------------------------------#
#   To predict with your own trained model, change the
#   model_path parameter.
#---------------------------------------------------#
class Siamese(object):
    """Inference wrapper around the network built by ``nets.siamese.siamese``.

    Builds the model and loads trained weights at construction time, then
    scores pairs of images with :meth:`detect_image`.

    Every key in ``_defaults`` becomes an instance attribute and can be
    overridden by passing it as a keyword argument to the constructor.
    """

    _defaults = {
        #-----------------------------------------------------#
        #   When predicting with your own trained model you must
        #   change model_path; it should point to a weight file
        #   in the logs folder.
        #-----------------------------------------------------#
        "model_path"        : 'model_data/Omniglot_vgg.h5',
        #-----------------------------------------------------#
        #   Size of the input images. detect_image hands it to
        #   letterbox_image as [input_shape[1], input_shape[0]],
        #   which suggests [height, width] order — TODO confirm.
        #-----------------------------------------------------#
        "input_shape"       : [105, 105],
        #--------------------------------------------------------------------#
        #   If True, letterbox_image resizes inputs without distortion;
        #   otherwise the image is center-cropped.
        #--------------------------------------------------------------------#
        "letterbox_image"   : False,
    }

    @classmethod
    def get_defaults(cls, n):
        """Return the default value of configuration key ``n``.

        Unknown keys do not raise: a message string is returned instead.
        This is kept as-is for backward compatibility with existing callers.
        """
        if n in cls._defaults:
            return cls._defaults[n]
        else:
            return "Unrecognized attribute name '" + n + "'"

    #---------------------------------------------------#
    #   Initialize Siamese
    #---------------------------------------------------#
    def __init__(self, **kwargs):
        """Apply defaults and keyword overrides, then build and load the model.

        Args:
            **kwargs: overrides for entries of ``_defaults``. Every keyword
                is set as an instance attribute.
        """
        self.__dict__.update(self._defaults)
        for name, value in kwargs.items():
            setattr(self, name, value)
        self.sess = K.get_session()
        self.generate()
        # Report the configuration actually in effect. Printing ``_defaults``
        # alone would hide any overrides the caller passed in.
        config = dict(self._defaults)
        config.update(kwargs)
        show_config(**config)

    #---------------------------------------------------#
    #   Load the model
    #---------------------------------------------------#
    def generate(self):
        """Build the Siamese network and load weights from ``self.model_path``.

        ``~`` in the path is expanded before validating and loading.

        Raises:
            AssertionError: if the (expanded) path does not end with ``.h5``.
        """
        model_path = os.path.expanduser(self.model_path)
        assert model_path.endswith('.h5'), 'Keras model or weights must be a .h5 file.'
        #---------------------------#
        #   Build the model and load its weights
        #---------------------------#
        self.model = siamese([self.input_shape[0], self.input_shape[1], 3])
        # Load from the expanded path (not self.model_path) so the file that
        # was validated above is the one actually loaded, and "~/..." resolves.
        self.model.load_weights(model_path)
        print('{} model loaded.'.format(model_path))

    #---------------------------------------------------#
    #   Compare two images
    #---------------------------------------------------#
    def detect_image(self, image_1, image_2):
        """Score how similar two images are and display them side by side.

        Args:
            image_1: first input image (presumably a PIL image, given the
                utils.utils helpers it is passed to — TODO confirm).
            image_2: second input image, same expectations as ``image_1``.

        Returns:
            ``self.model.predict(...)[0]`` for the pair, i.e. the prediction
            with the batch dimension dropped. The original comments describe
            it as a probability.
        """
        #---------------------------------------------------------#
        #   Convert to RGB so grayscale images do not cause errors
        #   during prediction.
        #---------------------------------------------------------#
        image_1 = cvtColor(image_1)
        image_2 = cvtColor(image_2)
        #---------------------------------------------------#
        #   Resize to the network input size: distortion-free
        #   letterboxing if self.letterbox_image, else center crop.
        #---------------------------------------------------#
        image_1 = letterbox_image(image_1, [self.input_shape[1], self.input_shape[0]], self.letterbox_image)
        image_2 = letterbox_image(image_2, [self.input_shape[1], self.input_shape[0]], self.letterbox_image)
        #---------------------------------------------------------#
        #   Normalize and add a leading batch dimension of size 1.
        #---------------------------------------------------------#
        photo1 = np.expand_dims(preprocess_input(np.array(image_1, np.float32)), 0)
        photo2 = np.expand_dims(preprocess_input(np.array(image_2, np.float32)), 0)
        #---------------------------------------------------#
        #   Predict on the pair; [0] drops the batch dimension.
        #---------------------------------------------------#
        output = self.model.predict([photo1, photo2])[0]
        # Convert explicitly to a Python scalar for the label. Using '%' on a
        # 1-element ndarray relies on implicit array-to-scalar conversion,
        # which NumPy >= 1.25 deprecates.
        similarity = np.asarray(output).item()

        plt.subplot(1, 2, 1)
        plt.imshow(np.array(image_1))
        plt.subplot(1, 2, 2)
        plt.imshow(np.array(image_2))
        plt.text(-12, -12, 'Similarity:%.3f' % similarity, ha='center', va='bottom', fontsize=11)
        plt.show()
        return output

    def close_session(self):
        """Close the Keras backend session obtained in ``__init__``."""
        self.sess.close()