-
Notifications
You must be signed in to change notification settings - Fork 0
/
func_utils.py
100 lines (89 loc) · 3.86 KB
/
func_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import os
import torch
import numpy as np
from datasets.dotadevkit.dotadevkit.ops.ResultMerge import py_cpu_nms_poly_fast, py_cpu_nms_poly
def decode_prediction(predictions, dsets, args, img_id, down_ratio):
    """Turn decoded detections for one image into per-category boxes and scores.

    Args:
        predictions: Array indexed as ``predictions[0, :, :]``; only the first
            batch entry is used. Each row must hold at least 12 values:
            columns 0-1 are the centre point, columns 2-3 / 4-5 / 6-7 / 8-9
            are the ``tt`` / ``rr`` / ``bb`` / ``ll`` points (presumably the
            top / right / bottom / left edge points -- confirm against the
            decoder), column 10 is the score and column 11 the class index.
        dsets: Dataset object providing ``img_ids`` (a list), ``category``
            (indexable by class index) and ``load_image(index)``.
        args: Object providing ``input_w`` and ``input_h``.
        img_id: Identifier of the image; must be present in ``dsets.img_ids``.
        down_ratio: Factor the predicted coordinates are multiplied by before
            being rescaled to the loaded image's size.

    Returns:
        A tuple ``(pts0, scores0)`` of dicts keyed by every entry of
        ``dsets.category``. ``pts0[cat]`` is a list of ``(4, 2)`` float32
        arrays holding the corners in ``[tr, br, bl, tl]`` order, and
        ``scores0[cat]`` is the list of matching scores.

    Raises:
        ValueError: If ``img_id`` is not in ``dsets.img_ids``.
    """
    predictions = predictions[0, :, :]
    ori_image = dsets.load_image(dsets.img_ids.index(img_id))
    # Only height and width are needed; slicing also accepts 2-D
    # (single-channel) images, which a three-way unpack would reject.
    h, w = ori_image.shape[:2]

    pts0 = {cat: [] for cat in dsets.category}
    scores0 = {cat: [] for cat in dsets.category}
    for pred in predictions:
        cen_pt = np.asarray([pred[0], pred[1]], np.float32)
        tt = np.asarray([pred[2], pred[3]], np.float32)
        rr = np.asarray([pred[4], pred[5]], np.float32)
        bb = np.asarray([pred[6], pred[7]], np.float32)
        ll = np.asarray([pred[8], pred[9]], np.float32)
        # Each corner is the sum of two adjacent edge points minus the centre.
        tl = tt + ll - cen_pt
        bl = bb + ll - cen_pt
        tr = tt + rr - cen_pt
        br = bb + rr - cen_pt
        score = pred[10]
        category = dsets.category[int(pred[11])]

        pts = np.asarray([tr, br, bl, tl], np.float32)
        # Rescale x by down_ratio * w / input_w and y by down_ratio * h / input_h
        # (evaluation order kept as-is so float32 results do not change).
        pts[:, 0] = pts[:, 0] * down_ratio / args.input_w * w
        pts[:, 1] = pts[:, 1] * down_ratio / args.input_h * h

        pts0[category].append(pts)
        scores0[category].append(score)
    return pts0, scores0
def non_maximum_suppression(pts, scores, thresh=0.1):
    """Run polygon NMS over a set of four-corner boxes.

    Args:
        pts: Array of shape ``(N, 4, 2)`` with the ``(x, y)`` corners of each
            box (anything ``np.asarray`` accepts).
        scores: Array of shape ``(N,)`` with one score per box.
        thresh: Threshold forwarded to ``py_cpu_nms_poly_fast``. Defaults to
            ``0.1``, the value that used to be hard-coded here.

    Returns:
        A float64 array of shape ``(K, 9)`` with the kept boxes. Each row is
        ``[x0, y0, x1, y1, x2, y2, x3, y3, score]``. An empty input yields an
        empty ``(0, 9)`` array without calling the NMS routine.
    """
    pts = np.asarray(pts)
    scores = np.asarray(scores)
    if pts.shape[0] == 0:
        # Nothing to suppress; skip the external routine entirely.
        return np.zeros((0, 9), np.float64)

    # Flattening (N, 4, 2) row-major gives x0, y0, x1, y1, ... per box, the
    # same column order as concatenating the eight coordinate slices by hand.
    corners = pts[:, :4, :2].reshape(pts.shape[0], 8)
    nms_item = np.concatenate([corners, scores[:, np.newaxis]], axis=1)
    nms_item = np.asarray(nms_item, np.float64)
    keep_index = py_cpu_nms_poly_fast(dets=nms_item, thresh=thresh)
    return nms_item[keep_index]
def write_results(args,
                  model,
                  dsets,
                  down_ratio,
                  device,
                  decoder,
                  result_path,
                  print_ps=False):
    """Run the model on every sample of ``dsets`` and write per-category result files.

    For each sample the image is sent through ``model``, the output is decoded
    with ``decoder.ctdet_decode`` and ``decode_prediction``, and the boxes of
    every category except ``'background'`` are filtered with
    ``non_maximum_suppression``. Afterwards one file ``Task1_<category>.txt``
    is written into ``result_path`` for each non-background category, with one
    line per kept box: ``img_id score x0 y0 x1 y1 x2 y2 x3 y3``.

    Args:
        args: Options object forwarded to ``decode_prediction``.
        model: Callable applied to the image tensor.
        dsets: Dataset supporting ``len()`` and indexing; each item is a dict
            with at least the keys ``'image'`` and ``'img_id'``. It must also
            provide ``img_ids`` and ``category``.
        down_ratio: Forwarded to ``decode_prediction``.
        device: Device the image is moved to (a ``torch.device`` or anything
            ``torch.device()`` accepts).
        decoder: Object providing ``ctdet_decode(pr_decs)``.
        result_path: Directory that receives the result files; it is created
            if it does not exist.
        print_ps: When true, print a progress line after each sample.

    Returns:
        None. The results are written to disk.
    """
    results = {cat: {img_id: [] for img_id in dsets.img_ids} for cat in dsets.category}
    num_samples = len(dsets)
    # torch.cuda.synchronize only accepts CUDA devices, so skip it otherwise.
    on_cuda = torch.device(device).type == 'cuda'
    # Create the output directory before running inference so a bad path is
    # reported immediately instead of after the whole dataset was processed.
    if result_path:
        os.makedirs(result_path, exist_ok=True)

    for index in range(num_samples):
        data_dict = dsets[index]
        image = data_dict['image'].to(device)
        img_id = data_dict['img_id']

        with torch.no_grad():
            pr_decs = model(image)
        if on_cuda:
            torch.cuda.synchronize(device)

        predictions = decoder.ctdet_decode(pr_decs)
        pts0, scores0 = decode_prediction(predictions, dsets, args, img_id, down_ratio)

        # nms
        for cat in dsets.category:
            if cat == 'background':
                continue
            pts_cat = np.asarray(pts0[cat], np.float32)
            scores_cat = np.asarray(scores0[cat], np.float32)
            if pts_cat.shape[0]:
                nms_results = non_maximum_suppression(pts_cat, scores_cat)
                results[cat][img_id].extend(nms_results)
        if print_ps:
            print('testing {}/{} data {}'.format(index+1, num_samples, img_id))

    # Each stored row is [x0, y0, ..., x3, y3, score]; the score is written first.
    line_fmt = '{} {:.12f} {:.1f} {:.1f} {:.1f} {:.1f} {:.1f} {:.1f} {:.1f} {:.1f}\n'
    for cat in dsets.category:
        if cat == 'background':
            continue
        with open(os.path.join(result_path, 'Task1_{}.txt'.format(cat)), 'w') as f:
            f.writelines(
                line_fmt.format(img_id, pt[8], *pt[:8])
                for img_id, dets in results[cat].items()
                for pt in dets)