-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathroi_heads.py
112 lines (98 loc) · 3.47 KB
/
roi_heads.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
from typing import Any, Dict, List, Optional, Sequence
import torch
from mmdet.core import BaseAssigner
from mmdet.models import StandardRoIHead
from mmdet.models.builder import HEADS
from .samplers import BBoxIDsMixin, SamplingResultWithBBoxIDs
@HEADS.register_module()
class StandardRoIHeadWithBBoxIDs(StandardRoIHead):
    """Standard RoI head that threads per-proposal ID columns through sampling.

    Proposals are expected to carry extra trailing columns (index 5 onward)
    identifying each box; the sampler (a ``BBoxIDsMixin``) keeps those IDs
    aligned with the sampled boxes so they can be recovered after sampling.
    """

    # Declared for type checkers; the concrete objects are constructed by the
    # parent class from the training config.
    bbox_assigner: BaseAssigner
    bbox_sampler: BBoxIDsMixin

    def _sample(
        self,
        x: torch.Tensor,
        proposal_list: List[torch.Tensor],
        gt_bboxes: List[torch.Tensor],
        gt_labels: List[torch.Tensor],
        gt_bboxes_ignore: Sequence[Optional[torch.Tensor]],
    ):
        """Assign and sample proposals image by image, tracking bbox IDs.

        Side effects: caches ``self.bboxes`` (per-image sampled boxes) and
        ``self.bboxes_ids`` (each ID row prefixed with its image index).

        Returns:
            Per-image list of ``SamplingResultWithBBoxIDs``.
        """
        results: List[SamplingResultWithBBoxIDs] = []
        per_image = zip(proposal_list, gt_bboxes, gt_labels, gt_bboxes_ignore)
        for img_idx, (props, boxes, labels, ignore) in enumerate(per_image):
            # One feature slice per level for this image, keeping a batch dim.
            img_feats = [level[img_idx][None] for level in x]
            # Columns 5+ presumably carry the proposal IDs; columns 0-3 are
            # the box coordinates. NOTE(review): column 4 (score?) is dropped
            # here — confirm against the proposal generator.
            ids = props[:, 5:]
            coords = props[:, :4]
            assigned = self.bbox_assigner.assign(coords, boxes, ignore, labels)
            sampled = self.bbox_sampler.sample(
                assigned,
                coords,
                boxes,
                labels,
                ids,
                feats=img_feats,
            )
            results.append(sampled)

        # Cache the sampled boxes and their IDs, prefixing each ID row with
        # the index of the image it came from.
        image_column = torch.cat([
            torch.full_like(r.bbox_ids[:, [0]], img_idx)
            for img_idx, r in enumerate(results)
        ])
        all_ids = torch.cat([r.bbox_ids for r in results])
        self.bboxes = [r.bboxes for r in results]
        self.bboxes_ids = torch.cat((image_column, all_ids), dim=-1)
        return results

    def init_weights(self) -> None:
        """Initialize weights at most once.

        It is fine to have this called after initialization, since
        ``todd.distiller`` may transfer weights.
        """
        if self.is_init:
            return
        super().init_weights()

    def forward_train(
        self,
        x: torch.Tensor,
        img_metas: List[Dict[str, Any]],
        proposal_list: List[torch.Tensor],
        gt_bboxes: List[torch.Tensor],
        gt_labels: List[torch.Tensor],
        gt_bboxes_ignore: Optional[Sequence[Optional[torch.Tensor]]] = None,
        gt_masks: Optional[List[Any]] = None,
        **kwargs,
    ):
        """Run the training forward pass and collect bbox/mask losses."""
        if gt_bboxes_ignore is None:
            gt_bboxes_ignore = [None] * len(img_metas)

        losses: Dict[str, Any] = {}
        if self.with_bbox or self.with_mask:
            sampling_results = self._sample(
                x,
                proposal_list,
                gt_bboxes,
                gt_labels,
                gt_bboxes_ignore,
            )

        if self.with_bbox:
            bbox_results = self._bbox_forward_train(
                x,
                sampling_results,
                gt_bboxes,
                gt_labels,
                img_metas,
            )
            losses.update(bbox_results['loss_bbox'])

        if self.with_mask:
            # NOTE(review): this reads ``bbox_results`` and so assumes
            # ``self.with_bbox`` is also enabled — verify configs never turn
            # on masks without the bbox branch.
            mask_results = self._mask_forward_train(
                x,
                sampling_results,
                bbox_results['bbox_feats'],
                gt_masks,
                img_metas,
            )
            losses.update(mask_results['loss_mask'])

        return losses