diff --git a/wwf/vision/object_detection/metrics.py b/wwf/vision/object_detection/metrics.py
index 24ac09b4..a1cf14b9 100644
--- a/wwf/vision/object_detection/metrics.py
+++ b/wwf/vision/object_detection/metrics.py
@@ -17,8 +17,8 @@ def activ_to_bbox(acts, anchors, flatten=True):
 def bbox_to_activ(bboxes, anchors, flatten=True):
     "Return the target of the model on `anchors` for the `bboxes`."
     if flatten:
-        t_centers = (bboxes[...,:2] - anchors[...,:2]) / anchors[...,2:]
-        t_sizes = torch.log(bboxes[...,2:] / anchors[...,2:] + 1e-8)
+        t_centers = (bboxes[...,:2] - anchors[...,:2]) / anchors[...,2:]
+        t_sizes = torch.log(bboxes[...,2:] / anchors[...,2:] + 1e-8)
         return torch.cat([t_centers, t_sizes], -1).div_(bboxes.new_tensor([[0.1, 0.1, 0.2, 0.2]]))
     else: return [activ_to_bbox(act,anc) for act,anc in zip(acts, anchors)]
     return res
@@ -26,7 +26,11 @@ def bbox_to_activ(bboxes, anchors, flatten=True):
 def encode_class(idxs, n_classes):
     target = idxs.new_zeros(len(idxs), n_classes).float()
     mask = idxs != 0
-    i1s = LongTensor(list(range(len(idxs))))
+    if cuda.is_available():
+        tensor_fn = torch.cuda.LongTensor
+    else:
+        tensor_fn = LongTensor
+    i1s = tensor_fn(list(range(len(idxs))))
     target[i1s[mask],idxs[mask]-1] = 1
     return target
 
@@ -83,7 +87,7 @@ def intersection(anchors, targets):
     ancs, tgts = ancs.unsqueeze(1).expand(a,t,4), tgts.unsqueeze(0).expand(a,t,4)
     top_left_i = torch.max(ancs[...,:2], tgts[...,:2])
     bot_right_i = torch.min(ancs[...,2:], tgts[...,2:])
-    sizes = torch.clamp(bot_right_i - top_left_i, min=0)
+    sizes = torch.clamp(bot_right_i - top_left_i, min=0)
     return sizes[...,0] * sizes[...,1]
 
 def IoU_values(anchs, targs):
@@ -99,21 +103,21 @@ def __init__(self, gamma:float=2., alpha:float=0.25, pad_idx:int=0, scales=None
         self.gamma,self.alpha,self.pad_idx,self.reg_loss = gamma,alpha,pad_idx,reg_loss
         self.scales = ifnone(scales, [1,2**(-1/3), 2**(-2/3)])
         self.ratios = ifnone(ratios, [1/2,1,2])
-        
+
     def _change_anchors(self, sizes) -> bool:
         if not hasattr(self, 'sizes'): return True
         for sz1, sz2 in zip(self.sizes, sizes):
             if sz1[0] != sz2[0] or sz1[1] != sz2[1]: return True
         return False
-        
+
     def _create_anchors(self, sizes, device:torch.device):
         self.sizes = sizes
         self.anchors = create_anchors(sizes, self.ratios, self.scales).to(device)
-        
+
     def _unpad(self, bbox_tgt, clas_tgt):
         i = torch.min(torch.nonzero(clas_tgt-self.pad_idx))
         return tlbr2cthw(bbox_tgt[i:]), clas_tgt[i:]-1+self.pad_idx
-        
+
     def _focal_loss(self, clas_pred, clas_tgt):
         encoded_tgt = encode_class(clas_tgt, clas_pred.size(1))
         ps = torch.sigmoid(clas_pred.detach())
@@ -122,7 +126,7 @@ def _focal_loss(self, clas_pred, clas_tgt):
         weights.pow_(self.gamma).mul_(alphas)
         clas_loss = F.binary_cross_entropy_with_logits(clas_pred, encoded_tgt, weights, reduction='sum')
         return clas_loss
-        
+
     def _one_loss(self, clas_pred, bbox_pred, clas_tgt, bbox_tgt):
         bbox_tgt, clas_tgt = self._unpad(bbox_tgt, clas_tgt)
         matches = match_anchors(self.anchors, bbox_tgt)
@@ -139,7 +143,7 @@ def _one_loss(self, clas_pred, bbox_pred, clas_tgt, bbox_tgt):
         clas_tgt = torch.cat([clas_tgt.new_zeros(1).long(), clas_tgt])
         clas_tgt = clas_tgt[matches[clas_mask]]
         return bb_loss + self._focal_loss(clas_pred, clas_tgt)/torch.clamp(bbox_mask.sum(), min=1.)
-        
+
     def forward(self, output, bbox_tgts, clas_tgts):
         clas_preds, bbox_preds, sizes = output
         if self._change_anchors(sizes): self._create_anchors(sizes, clas_preds.device)
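
The `encode_class` hunk is the one functional change in this patch: it switches to `torch.cuda.LongTensor` for the row-index tensor when CUDA is available, so the indices land on the same device as `target` and `mask` (both of which follow the device of `idxs` via `new_zeros`). For reference only, a device-agnostic sketch of the same one-hot encoding, using `torch.arange` with an explicit `device` argument instead of branching on CUDA availability (an assumed alternative, not what this patch applies):

import torch

def encode_class(idxs, n_classes):
    # One-hot targets for focal loss; class index 0 is the background/padding slot.
    target = idxs.new_zeros(len(idxs), n_classes).float()
    mask = idxs != 0
    # torch.arange on idxs.device keeps the row indices on the same device
    # (CPU or GPU) as `target`, with no CUDA-availability branch needed.
    i1s = torch.arange(len(idxs), device=idxs.device)
    target[i1s[mask], idxs[mask] - 1] = 1
    return target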