-
Notifications
You must be signed in to change notification settings - Fork 1
/
train.py
214 lines (186 loc) · 8.71 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
# -*- coding: utf-8 -*-
import numpy as np
import mxnet as mx
from data_loader import get_iterators
from utils import *
from model import SSD,sizes_list,ratios_list
import time
from mxnet.ndarray.contrib import MultiBoxTarget, MultiBoxPrior
from mxnet import ndarray as nd
from mxnet import gluon
data_shape = (3,512,512)
std = np.array([61.04467501, 60.03631381, 60.7750983 ])
rgb_mean = np.array([ 130.063048, 129.967301, 124.410760])
ctx = mx.gpu(0)
resize = data_shape[1:]
rec_prefix = './dataset/data/rec/img_'+str(resize[0])+'_'+str(resize[1])
# num_class = 1
'''
Loss definitions
'''
class FocalLoss(gluon.loss.Loss):
    """Focal loss for the class-imbalanced detection head.

    Computes -alpha * (1 - p_t)^gamma * log(p_t), where p_t is the softmax
    probability the network assigns to the true class, so well-classified
    (easy) anchors contribute little and training focuses on hard examples.

    Parameters
    ----------
    axis : int, class axis used by softmax/pick (default -1).
    alpha : float, global loss scale (default 0.25).
    gamma : int, focusing exponent (default 2).
    batch_axis : int, axis treated as the batch dimension (default 0).
    """
    def __init__(self, axis=-1, alpha=0.25, gamma=2, batch_axis=0, **kwargs):
        super(FocalLoss, self).__init__(None, batch_axis, **kwargs)
        self.axis = axis
        self.alpha = alpha
        self.gamma = gamma
        self.batch_axis = batch_axis

    def hybrid_forward(self, F, y, label):
        # Turn raw scores into class probabilities, then pick the
        # probability of the labelled class for every anchor.
        probs = F.softmax(y)
        pt = probs.pick(label, axis=self.axis, keepdims=True)
        # (1 - pt)^gamma down-weights anchors the model already gets right.
        focal_weight = (1 - pt) ** self.gamma
        per_anchor = -(self.alpha * focal_weight) * pt.log()
        # Average everything except the batch axis -> one loss per sample.
        return per_anchor.mean(axis=self.batch_axis, exclude=True)
class SmoothL1Loss(gluon.loss.Loss):
    """Smooth-L1 (Huber-style) loss for bounding-box offset regression.

    Parameters
    ----------
    batch_axis : int, axis treated as the batch dimension (default 0).

    `hybrid_forward` takes the predicted offsets `y`, the MultiBoxTarget
    offsets `label`, and the anchor `mask` that zeroes out anchors not
    matched to any ground-truth box, and returns one loss value per sample.
    """
    def __init__(self, batch_axis=0, **kwargs):
        super(SmoothL1Loss, self).__init__(None, batch_axis, **kwargs)
        self.batch_axis = batch_axis

    def hybrid_forward(self, F, y, label, mask):
        # Only the masked (matched) anchors contribute to the loss.
        loss = F.smooth_l1((y - label) * mask, scalar=1.0)
        # FIX: use F.mean instead of nd.mean so the loss also works in
        # symbolic mode under hybridize(), consistent with FocalLoss above
        # which stays entirely within the F namespace.
        return F.mean(loss, axis=self.batch_axis, exclude=True)
lossdoc='''
使用AP分数作为分类评价的标准。
由于在模型检测问题中,反例占据了绝大多数,即使把所有的边框全部预测为反例已然会具有不错的精度。
因此不能直接使用分类精度作为评价标准。
AP曲线考虑在预测为正例的标签中真正为正例的概率(查准率, precise)
以及在全部正例中预测为正例的概率(召回率, recall),更能反映模型的正确性。
使用MAE(平均绝对值误差)作为回归评价的标准。
'''
from mxnet import metric
from mxnet import autograd
from mxnet.ndarray.contrib import MultiBoxDetection
import numpy as np
'''
Train net
'''
def evaluate_acc(net, data_iter, ctx):
    """Evaluate `net` over one full pass of `data_iter`.

    Returns a tuple (AP, box_metric): the average-precision score computed by
    `evaluate_MAP` (from utils) over all accumulated detections, and an MAE
    metric comparing target offsets with the masked predicted offsets.
    """
    data_iter.reset()
    box_metric = metric.MAE()
    outs, labels = None, None
    for i, batch in enumerate(data_iter):
        data = batch.data[0].as_in_context(ctx)
        label = batch.label[0].as_in_context(ctx)
        # print('acc',label.shape)
        anchors, box_preds, cls_preds = net(data)
        # MultiBoxTarget matches anchors to ground-truth boxes and extracts
        # the anchor offsets and the matched class labels. The regression
        # target is the network's predicted box g minus the anchor, divided
        # by the anchor [x, y, w, h], scored as smooth_l1(label - g).
        # negative_mining_ratio=3.0 keeps roughly a 1:3 positive:negative ratio.
        box_offset, box_mask, cls_labels = MultiBoxTarget(anchors, label, cls_preds.transpose((0, 2, 1)),
                                                          negative_mining_ratio=3.0)
        box_metric.update([box_offset], [box_preds * box_mask])
        cls_probs = nd.SoftmaxActivation(cls_preds.transpose((0, 2, 1)), mode='channel')
        # Filter the decoded detections with non-maximum suppression.
        out = MultiBoxDetection(cls_probs, box_preds, anchors, force_suppress=True, clip=False, nms_threshold=0.45)
        # Accumulate detections and labels across the whole iterator so AP is
        # computed over the entire dataset, not per batch.
        if outs is None:
            outs = out
            labels = label
        else:
            outs = nd.concat(outs, out, dim=0)
            labels = nd.concat(labels, label, dim=0)
    AP = evaluate_MAP(outs, labels)
    return AP, box_metric
# Training history filled in by mytrain: per-epoch train/valid AP and
# [cls_loss, box_loss] pairs.
info = {"train_ap": [], "valid_ap": [], "loss": []}
def plot(key):
    # Plot one history series from `info` against its epoch index.
    # NOTE(review): `plt` is not imported here — presumably re-exported by
    # `from utils import *`; confirm against utils.py.
    plt.plot(range(len(info[key])), info[key], label=key)
def mytrain(net, train_data, valid_data, ctx, start_epoch, end_epoch, cls_loss, box_loss, trainer=None):
    """Train `net` over epochs [start_epoch, end_epoch) and record AP/loss history.

    Parameters
    ----------
    net : SSD network returning (anchors, box_preds, cls_preds) for a batch.
    train_data, valid_data : MXNet data iterators (reset every epoch).
    ctx : device context the batches are copied to.
    start_epoch, end_epoch : epoch range, end exclusive.
    cls_loss : classification loss (FocalLoss) on the class predictions.
    box_loss : regression loss (SmoothL1Loss) on the masked box offsets.
    trainer : optional gluon.Trainer; defaults to SGD(lr=0.1, wd=1e-3).

    Side effects: appends per-epoch metrics to the module-level `info` dict,
    prints progress every 10 epochs, and saves 'loss_curve.png' at the end.
    """
    if trainer is None:
        # trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.01,'momentum':0.9, 'wd':5e-1})
        trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.1, 'wd':1e-3})
    box_metric = metric.MAE()
    for e in range(start_epoch, end_epoch):
        # print(e)
        train_data.reset()
        valid_data.reset()
        box_metric.reset()
        tic = time.time()
        _loss = [0, 0]  # summed [classification, box-regression] loss this epoch
        # Step decay: shrink the learning rate 5x at these fixed epochs.
        if e == 100 or e == 120 or e == 150 or e == 180 or e == 200:
            trainer.set_learning_rate(trainer.learning_rate * 0.2)
        outs, labels = None, None
        for i, batch in enumerate(train_data):
            data = batch.data[0].as_in_context(ctx)
            label = batch.label[0].as_in_context(ctx)
            # print(label.shape)
            with autograd.record():
                anchors, box_preds, cls_preds = net(data)
                # print(anchors.shape,box_preds.shape,cls_preds.shape)
                # negative_mining_ratio: add ~3x hard negatives into the mask
                # so they participate in the loss alongside the positives.
                box_offset, box_mask, cls_labels = MultiBoxTarget(anchors, label, cls_preds.transpose(axes=(0, 2, 1)),
                                                                  negative_mining_ratio=3.0)  # , overlap_threshold=0.75)
                loss1 = cls_loss(cls_preds, cls_labels)
                loss2 = box_loss(box_preds, box_offset, box_mask)
                loss = loss1 + loss2
                # print(loss1.shape,loss2.shape)
            loss.backward()
            trainer.step(data.shape[0])
            _loss[0] += nd.mean(loss1).asscalar()
            _loss[1] += nd.mean(loss2).asscalar()
            cls_probs = nd.SoftmaxActivation(cls_preds.transpose((0, 2, 1)), mode='channel')
            # NMS-filter the decoded detections for the epoch-level AP score.
            out = MultiBoxDetection(cls_probs, box_preds, anchors, force_suppress=True, clip=False, nms_threshold=0.45)
            # Accumulate detections/labels across the epoch so AP covers all batches.
            if outs is None:
                outs = out
                labels = label
            else:
                outs = nd.concat(outs, out, dim=0)
                labels = nd.concat(labels, label, dim=0)
            box_metric.update([box_offset], [box_preds * box_mask])
        train_AP = evaluate_MAP(outs, labels)
        valid_AP, val_box_metric = evaluate_acc(net, valid_data, ctx)
        info["train_ap"].append(train_AP)
        info["valid_ap"].append(valid_AP)
        info["loss"].append(_loss)
        if (e + 1) % 10 == 0:
            print("epoch: %d time: %.2f loss: %.4f, %.4f lr: %.5f" % (
                e, time.time() - tic, _loss[0], _loss[1], trainer.learning_rate))
            print("train mae: %.4f AP: %.4f" % (box_metric.get()[1], train_AP))
            print("valid mae: %.4f AP: %.4f" % (val_box_metric.get()[1], valid_AP))
    if True:
        # Plot AP curves and the two loss components side by side and save.
        info["loss"] = np.array(info["loss"])
        info["cls_loss"] = info["loss"][:, 0]
        info["box_loss"] = info["loss"][:, 1]
        plt.figure(figsize=(12, 4))
        plt.subplot(121)
        plot("train_ap")
        plot("valid_ap")
        plt.legend(loc="upper right")
        plt.subplot(122)
        plot("cls_loss")
        plot("box_loss")
        plt.legend(loc="upper right")
        plt.savefig('loss_curve.png')
if __name__ == '__main__':
    batch_size = 4
    # 1. Load the RecordIO dataset iterators (and optionally visualize a batch).
    train_data, valid_data, class_names, num_classes = get_iterators(rec_prefix, data_shape, batch_size)
    # train_data.reset()
    # The per-image label slot must hold at least 3 entries; pad the label
    # shape up to (3, 5) otherwise. NOTE(review): this .next() consumes one
    # batch before training starts — presumably intentional; confirm.
    if train_data.next().label[0][0].shape[0] < 3:
        train_data.reshape(label_shape=(3, 5))
        valid_data.reshape(label_shape=(3, 5))
        # valid_data.sync_label_shape(train_data)
    # Debug visualization of one training batch (disabled).
    if False:
        batch = train_data.next()
        images = batch.data[0][:].as_in_context(mx.gpu(0))
        labels = batch.label[0][:].as_in_context(mx.gpu(0))
        show_images(images.asnumpy(), labels.asnumpy(), rgb_mean, std, show_text=True, fontsize=6, MN=(2, 4))
        print(labels.shape)
    # 2. Build the SSD network with a single foreground class.
    net = SSD(1, verbose=False, prefix='ssd_')
    # net.hybridize() # MultiBoxPrior cannot support symbol
    # print(net)
    # tic = time.time()
    # anchors,box_preds,cls_preds = net(images)
    # print(time.time()-tic)
    # print(net)
    # MultiBoxTarget matches anchors to ground-truth boxes and extracts the
    # anchor offsets and matched class labels; the regression target is the
    # predicted box minus the anchor, normalized by the anchor [x, y, w, h],
    # solved as smooth_l1(label - g) where g is the prediction.
    # box_offset,box_mask,cls_labels = MultiBoxTarget(anchors,batch.label[0],cls_preds)
    # box_offset, box_mask, cls_labels = MultiBoxTarget(anchors, batch.label[0].as_in_context(mx.gpu(0)),
    #                                                   cls_preds.transpose((0, 2, 1)))
    # 3. Losses: focal loss for classification, smooth-L1 for box regression.
    cls_loss = FocalLoss()  # predict
    box_loss = SmoothL1Loss()  # regression
    # 4. Train for 220 epochs, then save the weights.
    mytrain(net, train_data, valid_data, ctx, 0, 220, cls_loss, box_loss)
    mkdir_if_not_exist("./Model")
    # net.save_params("./Model/mobilenet1.0_papercupDetect.param")
    net.save_params("./Model/resnet18_papercupDetect.param")