Skip to content

Commit

Permalink
Merge pull request #65 from nez/master
Browse files Browse the repository at this point in the history
When cuda is available encode_class uses torch.cuda.LongTensor.
  • Loading branch information
muellerzr authored Mar 1, 2023
2 parents 08cda08 + 1e34c53 commit ecdd5ed
Showing 1 changed file with 14 additions and 10 deletions.
24 changes: 14 additions & 10 deletions wwf/vision/object_detection/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,20 @@ def activ_to_bbox(acts, anchors, flatten=True):
def bbox_to_activ(bboxes, anchors, flatten=True):
"Return the target of the model on `anchors` for the `bboxes`."
if flatten:
t_centers = (bboxes[...,:2] - anchors[...,:2]) / anchors[...,2:]
t_sizes = torch.log(bboxes[...,2:] / anchors[...,2:] + 1e-8)
t_centers = (bboxes[...,:2] - anchors[...,:2]) / anchors[...,2:]
t_sizes = torch.log(bboxes[...,2:] / anchors[...,2:] + 1e-8)
return torch.cat([t_centers, t_sizes], -1).div_(bboxes.new_tensor([[0.1, 0.1, 0.2, 0.2]]))
else: return [activ_to_bbox(act,anc) for act,anc in zip(acts, anchors)]
return res

def encode_class(idxs, n_classes):
target = idxs.new_zeros(len(idxs), n_classes).float()
mask = idxs != 0
i1s = LongTensor(list(range(len(idxs))))
if cuda.is_available():
tensor_fn = torch.cuda.LongTensor
else:
tensor_fn = LongTensor
i1s = tensor_fn(list(range(len(idxs))))
target[i1s[mask],idxs[mask]-1] = 1
return target

Expand Down Expand Up @@ -83,7 +87,7 @@ def intersection(anchors, targets):
ancs, tgts = ancs.unsqueeze(1).expand(a,t,4), tgts.unsqueeze(0).expand(a,t,4)
top_left_i = torch.max(ancs[...,:2], tgts[...,:2])
bot_right_i = torch.min(ancs[...,2:], tgts[...,2:])
sizes = torch.clamp(bot_right_i - top_left_i, min=0)
sizes = torch.clamp(bot_right_i - top_left_i, min=0)
return sizes[...,0] * sizes[...,1]

def IoU_values(anchs, targs):
Expand All @@ -99,21 +103,21 @@ def __init__(self, gamma:float=2., alpha:float=0.25, pad_idx:int=0, scales=None
self.gamma,self.alpha,self.pad_idx,self.reg_loss = gamma,alpha,pad_idx,reg_loss
self.scales = ifnone(scales, [1,2**(-1/3), 2**(-2/3)])
self.ratios = ifnone(ratios, [1/2,1,2])

def _change_anchors(self, sizes) -> bool:
if not hasattr(self, 'sizes'): return True
for sz1, sz2 in zip(self.sizes, sizes):
if sz1[0] != sz2[0] or sz1[1] != sz2[1]: return True
return False

def _create_anchors(self, sizes, device:torch.device):
self.sizes = sizes
self.anchors = create_anchors(sizes, self.ratios, self.scales).to(device)

def _unpad(self, bbox_tgt, clas_tgt):
i = torch.min(torch.nonzero(clas_tgt-self.pad_idx))
return tlbr2cthw(bbox_tgt[i:]), clas_tgt[i:]-1+self.pad_idx

def _focal_loss(self, clas_pred, clas_tgt):
encoded_tgt = encode_class(clas_tgt, clas_pred.size(1))
ps = torch.sigmoid(clas_pred.detach())
Expand All @@ -122,7 +126,7 @@ def _focal_loss(self, clas_pred, clas_tgt):
weights.pow_(self.gamma).mul_(alphas)
clas_loss = F.binary_cross_entropy_with_logits(clas_pred, encoded_tgt, weights, reduction='sum')
return clas_loss

def _one_loss(self, clas_pred, bbox_pred, clas_tgt, bbox_tgt):
bbox_tgt, clas_tgt = self._unpad(bbox_tgt, clas_tgt)
matches = match_anchors(self.anchors, bbox_tgt)
Expand All @@ -139,7 +143,7 @@ def _one_loss(self, clas_pred, bbox_pred, clas_tgt, bbox_tgt):
clas_tgt = torch.cat([clas_tgt.new_zeros(1).long(), clas_tgt])
clas_tgt = clas_tgt[matches[clas_mask]]
return bb_loss + self._focal_loss(clas_pred, clas_tgt)/torch.clamp(bbox_mask.sum(), min=1.)

def forward(self, output, bbox_tgts, clas_tgts):
clas_preds, bbox_preds, sizes = output
if self._change_anchors(sizes): self._create_anchors(sizes, clas_preds.device)
Expand Down

0 comments on commit ecdd5ed

Please sign in to comment.