-
Notifications
You must be signed in to change notification settings - Fork 26
/
Copy pathlosses.py
193 lines (136 loc) · 7.69 KB
/
losses.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
import numpy as np
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
def calc_iou(a, b):
    """Pairwise intersection-over-union between two sets of boxes.

    Boxes are read from the first four columns as ``(x1, y1, x2, y2)``; any
    further columns are ignored.

    Args:
        a: boxes, one per row.
        b: boxes, one per row.

    Returns:
        Tensor of shape ``[len(a), len(b)]`` whose entry ``[i, k]`` is the IoU of
        ``a[i]`` and ``b[k]``. The union is clamped to ``1e-8`` so degenerate
        (zero-area, non-overlapping) pairs yield 0 rather than NaN.
    """
    # Column views of `a` are lifted to shape [N, 1] so they broadcast
    # against the [K] columns of `b`, giving [N, K] results.
    a_x1, a_y1, a_x2, a_y2 = (a[:, col].unsqueeze(1) for col in range(4))
    b_x1, b_y1, b_x2, b_y2 = b[:, 0], b[:, 1], b[:, 2], b[:, 3]

    overlap_w = (torch.min(a_x2, b_x2) - torch.max(a_x1, b_x1)).clamp(min=0)
    overlap_h = (torch.min(a_y2, b_y2) - torch.max(a_y1, b_y1)).clamp(min=0)
    overlap = overlap_w * overlap_h

    area_a = (a_x2 - a_x1) * (a_y2 - a_y1)
    area_b = (b_x2 - b_x1) * (b_y2 - b_y1)
    union = (area_a + area_b - overlap).clamp(min=1e-8)
    return overlap / union
def calc_iou_vis(a, b):
    """Pairwise intersection normalised by the area of the boxes in ``b``.

    Despite the name this is not IoU: entry ``[i, k]`` is the fraction of
    ``b[k]``'s area that is covered by ``a[i]``. Boxes are read from the first
    four columns as ``(x1, y1, x2, y2)``.

    Args:
        a: boxes, one per row.
        b: boxes, one per row; their areas form the denominator.

    Returns:
        Tensor of shape ``[len(a), len(b)]``.
    """
    area = (b[:, 2] - b[:, 0]) * (b[:, 3] - b[:, 1])
    # Zero-area boxes in `b` would otherwise produce 0/0 = NaN; clamp the
    # denominator the same way calc_iou clamps its union.
    area = torch.clamp(area, min=1e-8)
    # Unsqueezing `a`'s columns to [N, 1] broadcasts against `b`'s [K] columns.
    iw = torch.min(torch.unsqueeze(a[:, 2], dim=1), b[:, 2]) - torch.max(torch.unsqueeze(a[:, 0], 1), b[:, 0])
    ih = torch.min(torch.unsqueeze(a[:, 3], dim=1), b[:, 3]) - torch.max(torch.unsqueeze(a[:, 1], 1), b[:, 1])
    iw = torch.clamp(iw, min=0)
    ih = torch.clamp(ih, min=0)
    intersection = iw * ih
    return intersection / area
def IoG(box_a, box_b):
    """Row-wise intersection over the area of ``box_a``.

    This is element-wise, not pairwise: row ``i`` of ``box_a`` is compared only
    with row ``i`` of ``box_b`` (both must have the same number of rows, or be
    broadcastable). Boxes are read from the first four columns as
    ``(x1, y1, x2, y2)``.

    Args:
        box_a: boxes whose areas form the denominator.
        box_b: boxes intersected with ``box_a`` row by row.

    Returns:
        1-D tensor: the fraction of each ``box_a`` row covered by the matching
        ``box_b`` row.
    """
    inter_xmin = torch.max(box_a[:, 0], box_b[:, 0])
    inter_ymin = torch.max(box_a[:, 1], box_b[:, 1])
    inter_xmax = torch.min(box_a[:, 2], box_b[:, 2])
    inter_ymax = torch.min(box_a[:, 3], box_b[:, 3])
    inter_w = torch.clamp(inter_xmax - inter_xmin, min=0)
    inter_h = torch.clamp(inter_ymax - inter_ymin, min=0)
    intersection = inter_w * inter_h
    area_a = (box_a[:, 2] - box_a[:, 0]) * (box_a[:, 3] - box_a[:, 1])
    # A zero-area box_a would otherwise give 0/0 = NaN.
    area_a = torch.clamp(area_a, min=1e-8)
    return intersection / area_a
class FocalLoss(nn.Module):
    """Focal classification loss plus smooth-L1 box regression loss.

    Anchor assignment combines full-body boxes (``bboxs``) and visible-region
    boxes (``vboxs``):

    * an anchor is positive when its best IoU with a full-body box is >= 0.5
      AND the matched visible box is at least half covered by the anchor
      (``IoG`` >= 0.5);
    * an anchor is background when its best full-body IoU is < 0.4;
    * every other anchor is ignored (target -1);
    * anchors whose IoU with any ``ignores`` box is >= 0.5 are dropped from
      both losses entirely.
    """

    # Focal-loss hyper-parameters. Kept as class attributes (not constructor
    # arguments) so existing `FocalLoss()` calls and pickled instances keep
    # working, while subclasses or instances can still override them.
    alpha = 0.25
    gamma = 2.0

    def forward(self, classifications, regressions, anchors, bboxs, vboxs, ignores):
        """Compute batch-averaged classification and regression losses.

        Args:
            classifications: per-anchor class scores indexed
                ``[image, anchor, class]``. They are clamped to
                ``[1e-4, 1 - 1e-4]`` and fed to ``log`` directly, so they are
                presumably sigmoid probabilities -- TODO confirm with the head.
            regressions: per-anchor box deltas indexed ``[image, anchor, 4]``.
            anchors: anchor boxes ``(x1, y1, x2, y2)``; only ``anchors[0]`` is
                read, so the same anchors are used for every image.
            bboxs: full-body annotations indexed ``[image, box, 5]``; columns
                0-3 are the box and column 4 the class index. Rows whose
                column 4 equals -1 are treated as padding.
            vboxs: visible-region annotations laid out like ``bboxs``.
            ignores: ignore regions laid out like ``bboxs``.

        Returns:
            ``(classification_loss, regression_loss)``, each a 1-element tensor
            holding the mean over the images in the batch.
        """
        alpha = self.alpha
        gamma = self.gamma
        # Create every helper tensor on the inputs' device instead of the
        # current CUDA device, so the loss also works on CPU / other GPUs.
        device = classifications.device
        batch_size = classifications.shape[0]
        classification_losses = []
        regression_losses = []

        # Anchors are shared by the whole batch: compute their geometry once.
        anchor = anchors[0, :, :]
        anchor_widths = anchor[:, 2] - anchor[:, 0]
        anchor_heights = anchor[:, 3] - anchor[:, 1]
        anchor_ctr_x = anchor[:, 0] + 0.5 * anchor_widths
        anchor_ctr_y = anchor[:, 1] + 0.5 * anchor_heights

        # Normalisation divisor applied to the encoded (dx, dy, dw, dh) targets.
        box_std = torch.tensor([[0.1, 0.1, 0.2, 0.2]], device=device)

        for j in range(batch_size):
            classification = classifications[j, :, :]
            regression = regressions[j, :, :]

            # Strip padding rows (column 4 == -1).
            bbox_annotation = bboxs[j, :, :]
            bbox_annotation = bbox_annotation[bbox_annotation[:, 4] != -1]
            vbox_annotation = vboxs[j, :, :]
            vbox_annotation = vbox_annotation[vbox_annotation[:, 4] != -1]

            if bbox_annotation.shape[0] == 0:
                # No objects in this image: it contributes zero to both losses
                # (background anchors are not penalised here either).
                regression_losses.append(torch.tensor(0.0, device=device))
                classification_losses.append(torch.tensor(0.0, device=device))
                continue

            # Keep scores away from 0 and 1 so the logs below stay finite.
            classification = torch.clamp(classification, 1e-4, 1.0 - 1e-4)

            # Drop anchors that overlap an ignore region with IoU >= 0.5.
            ignore = ignores[j, :, :]
            ignore = ignore[ignore[:, 4] != -1]
            if ignore.shape[0] > 0:
                iou_igno_max, _ = torch.max(calc_iou(anchor, ignore), dim=1)
                keep = torch.lt(iou_igno_max, 0.5)
                anchor_keep = anchor[keep, :]
                classification = classification[keep, :]
                regression = regression[keep, :]
                anchor_widths_keep = anchor_widths[keep]
                anchor_heights_keep = anchor_heights[keep]
                anchor_ctr_x_keep = anchor_ctr_x[keep]
                anchor_ctr_y_keep = anchor_ctr_y[keep]
            else:
                anchor_keep = anchor
                anchor_widths_keep = anchor_widths
                anchor_heights_keep = anchor_heights
                anchor_ctr_x_keep = anchor_ctr_x
                anchor_ctr_y_keep = anchor_ctr_y

            # Match each kept anchor to the full-body box it overlaps most.
            IoU = calc_iou(anchor_keep, bbox_annotation[:, :4])  # num_anchors x num_annotations
            IoU_max, IoU_argmax = torch.max(IoU, dim=1)  # num_anchors
            # NOTE(review): indexing vboxs with the bbox argmax assumes the
            # visible boxes line up row-for-row with the full-body boxes after
            # padding removal -- confirm against the data loader.
            assigned_annotations_vis = vbox_annotation[IoU_argmax, :]
            # Fraction of the matched visible box lying inside each anchor.
            IoU_vis = IoG(assigned_annotations_vis, anchor_keep)

            # ---- classification loss --------------------------------------
            # Targets: -1 = ignored, 0 = background, 1 = object of that class.
            targets = torch.full(classification.shape, -1.0, device=device)
            targets[torch.lt(IoU_max, 0.4), :] = 0

            positive_indices = torch.ge(IoU_max, 0.5) & torch.ge(IoU_vis, 0.5)
            num_positive_anchors = positive_indices.sum()

            assigned_annotations = bbox_annotation[IoU_argmax, :]
            targets[positive_indices, :] = 0
            targets[positive_indices, assigned_annotations[positive_indices, 4].long()] = 1

            alpha_factor = torch.full_like(targets, alpha)
            alpha_factor = torch.where(torch.eq(targets, 1.), alpha_factor, 1. - alpha_factor)
            focal_weight = torch.where(torch.eq(targets, 1.), 1. - classification, classification)
            focal_weight = alpha_factor * torch.pow(focal_weight, gamma)

            bce = -(targets * torch.log(classification) + (1.0 - targets) * torch.log(1.0 - classification))
            cls_loss = focal_weight * bce
            # Ignored anchors (target -1) contribute nothing.
            cls_loss = torch.where(torch.ne(targets, -1.0), cls_loss, torch.zeros_like(cls_loss))
            classification_losses.append(cls_loss.sum() / torch.clamp(num_positive_anchors.float(), min=1.0))

            # ---- regression loss (positive anchors only) -------------------
            if num_positive_anchors > 0:
                assigned_annotations = assigned_annotations[positive_indices, :]

                anchor_widths_pi = anchor_widths_keep[positive_indices]
                anchor_heights_pi = anchor_heights_keep[positive_indices]
                anchor_ctr_x_pi = anchor_ctr_x_keep[positive_indices]
                anchor_ctr_y_pi = anchor_ctr_y_keep[positive_indices]

                gt_widths = assigned_annotations[:, 2] - assigned_annotations[:, 0]
                gt_heights = assigned_annotations[:, 3] - assigned_annotations[:, 1]
                gt_ctr_x = assigned_annotations[:, 0] + 0.5 * gt_widths
                gt_ctr_y = assigned_annotations[:, 1] + 0.5 * gt_heights

                # Clip sizes to at least 1 so the logs below stay finite.
                gt_widths = torch.clamp(gt_widths, min=1)
                gt_heights = torch.clamp(gt_heights, min=1)

                targets_dx = (gt_ctr_x - anchor_ctr_x_pi) / anchor_widths_pi
                targets_dy = (gt_ctr_y - anchor_ctr_y_pi) / anchor_heights_pi
                targets_dw = torch.log(gt_widths / anchor_widths_pi)
                targets_dh = torch.log(gt_heights / anchor_heights_pi)

                regression_targets = torch.stack((targets_dx, targets_dy, targets_dw, targets_dh)).t()
                regression_targets = regression_targets / box_std

                # Smooth-L1 with transition point 1/9.
                regression_diff = torch.abs(regression_targets - regression[positive_indices, :])
                regression_loss = torch.where(
                    torch.le(regression_diff, 1.0 / 9.0),
                    0.5 * 9.0 * torch.pow(regression_diff, 2),
                    regression_diff - 0.5 / 9.0,
                )
                regression_losses.append(regression_loss.mean())
            else:
                regression_losses.append(torch.tensor(0.0, device=device))

        return (
            torch.stack(classification_losses).mean(dim=0, keepdim=True),
            torch.stack(regression_losses).mean(dim=0, keepdim=True),
        )