-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdatasets.py
50 lines (40 loc) · 1.6 KB
/
datasets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
from mmdet.datasets import DATASETS
from mmdet.datasets import CocoDataset as _CocoDataset
from mmdet.datasets import CustomDataset
from head.utils import Debug
class DebugMixin(CustomDataset):
    """Dataset mixin that shrinks the dataset when ``Debug.LESS_DATA`` is set.

    With ``Debug.LESS_DATA`` enabled, only the first ``DEBUG_NUM_SAMPLES``
    annotations and proposals are kept, and ``len()`` reports at most that
    many samples. With the flag disabled, ``__len__``, ``load_annotations``
    and ``load_proposals`` defer unchanged to the next class in the MRO.
    """

    # Upper bound on the number of samples kept in ``Debug.LESS_DATA`` mode.
    DEBUG_NUM_SAMPLES = 4

    def __len__(self) -> int:
        """Return the dataset length, capped in ``Debug.LESS_DATA`` mode.

        The cap is applied with ``min`` so the reported length never exceeds
        the number of samples the base class actually holds (e.g. when the
        dataset has fewer samples than the cap, or fewer remain after any
        later filtering). Reporting more would make index-based access run
        off the end of the data.
        """
        if not Debug.LESS_DATA:
            return super().__len__()
        try:
            full_len = super().__len__()
        except AttributeError:
            # NOTE(review): assumes the base ``__len__`` raises AttributeError
            # when called mid-construction, before the loaded annotations are
            # stored (e.g. from a subclass's ``load_annotations``). Fall back
            # to the cap, which is what this method always returned before.
            return self.DEBUG_NUM_SAMPLES
        return min(full_len, self.DEBUG_NUM_SAMPLES)

    def load_annotations(self, *args, **kwargs):
        """Load annotations via the base class, truncating in debug mode."""
        data_infos = super().load_annotations(*args, **kwargs)
        if Debug.LESS_DATA:
            # Slice by the constant rather than ``len(self)``: this may run
            # during construction, before ``len(self)`` reflects loaded data.
            data_infos = data_infos[:self.DEBUG_NUM_SAMPLES]
        return data_infos

    def load_proposals(self, *args, **kwargs):
        """Load proposals via the base class, truncating in debug mode."""
        proposals = super().load_proposals(*args, **kwargs)
        if Debug.LESS_DATA:
            proposals = proposals[:self.DEBUG_NUM_SAMPLES]
        return proposals

    def evaluate(self, *args, **kwargs):
        """Delegate to the base ``evaluate`` after dropping two kwargs.

        ``gpu_collect`` and ``tmpdir`` are discarded (if present) before the
        call. NOTE(review): presumably because a caller such as an evaluation
        hook passes these collection options while the wrapped ``evaluate``
        does not accept them. Confirm.
        """
        kwargs.pop('gpu_collect', None)
        kwargs.pop('tmpdir', None)
        return super().evaluate(*args, **kwargs)
@DATASETS.register_module(force=True)
class CocoDataset(DebugMixin, _CocoDataset):
    """mmdet ``CocoDataset`` combined with ``DebugMixin`` truncation.

    Registered with ``force=True`` so this class takes over the
    ``CocoDataset`` registry name already used by mmdet's own class.
    """

    def load_annotations(self, *args, **kwargs):
        """Load annotations and, in debug mode, trim the COCO index to match.

        ``DebugMixin`` truncates the returned ``data_infos``. Here the
        underlying ``self.coco.dataset`` images and annotations,
        ``self.img_ids`` and ``self.coco.imgs`` are restricted to the same
        number of leading images, so the COCO index matches the kept infos.
        """
        data_infos = super().load_annotations(*args, **kwargs)
        if Debug.LESS_DATA:
            coco = self.coco
            # Size the trim from the returned infos rather than ``len(self)``,
            # which is called mid-construction here. NOTE(review): assumes
            # ``data_infos`` follows the order of ``coco.dataset['images']``.
            kept_images = coco.dataset['images'][:len(data_infos)]
            coco.dataset['images'] = kept_images
            self.img_ids = [img['id'] for img in kept_images]
            # Set lookup: the annotation list can be large.
            kept_ids = set(self.img_ids)
            coco.dataset['annotations'] = [
                ann for ann in coco.dataset['annotations']
                if ann['image_id'] in kept_ids
            ]
            coco.imgs = {img['id']: img for img in kept_images}
        return data_infos