-
Notifications
You must be signed in to change notification settings - Fork 0
/
evaluator.py
executable file
·52 lines (42 loc) · 2.58 KB
/
evaluator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import torch
from typing import Tuple
from torch.utils.data import DataLoader
from tqdm import tqdm
from model import Model
class Evaluator(object):
    """Run a detection model over a dataset and score the collected detections.

    The evaluator wraps ``dataset`` in a non-shuffling ``DataLoader``, feeds
    every batch through the model, rescales the predicted boxes by the
    per-image scale, drops low-probability detections and finally delegates
    scoring to ``dataset.evaluate``.
    """

    # Detections whose probability is not strictly greater than this value
    # are discarded before scoring.
    _MIN_DETECTION_PROB = 0.05

    def __init__(self, dataset, batch_size, path_to_data_dir, path_to_results_dir, device, kwargs):
        """Store the evaluation configuration and build the data loader.

        Args:
            dataset: Dataset to evaluate on. It must be usable by
                ``DataLoader`` and provide an ``evaluate`` method (called by
                :meth:`evaluate`).
            batch_size: Batch size passed to the ``DataLoader``.
            path_to_data_dir: Stored on the instance; not used by this class
                itself.
            path_to_results_dir: Forwarded to ``dataset.evaluate``.
            device: Device the image batches are moved to before inference.
            kwargs: Mapping of extra keyword arguments forwarded to the
                ``DataLoader`` constructor (e.g. worker settings).
        """
        super().__init__()
        self._dataset = dataset
        # Order is kept deterministic for evaluation: never shuffle.
        self._dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, **kwargs)
        self._path_to_data_dir = path_to_data_dir
        self._path_to_results_dir = path_to_results_dir
        self._device = device

    def evaluate(self, model: Model) -> Tuple[float, str]:
        """Run ``model`` over the whole dataset and score its detections.

        The model is switched to evaluation mode and is left in that mode
        when this method returns.

        Args:
            model: Detection model. ``model.forward(image_batch)`` must return
                ``(bboxes, classes, probs, batch_indices)`` where
                ``batch_indices`` maps each detection to its position in the
                batch.

        Returns:
            The ``(mean_ap, detail)`` pair produced by ``dataset.evaluate``.
        """
        all_image_ids, all_detection_bboxes, all_detection_classes, all_detection_probs = [], [], [], []

        # Switching to eval mode is loop-invariant, so do it once up front.
        model.eval()

        with torch.no_grad():
            for image_id_batch, image_batch, scale_batch, _, _ in tqdm(self._dataloader):
                image_batch = image_batch.to(self._device)

                detection_bboxes, \
                    detection_classes, \
                    detection_probs, \
                    detection_batch_indices = model.forward(image_batch)

                # Move the scales to the detections' device BEFORE indexing:
                # indexing a CPU tensor with indices that live on another
                # device is rejected by PyTorch.
                scale_batch = scale_batch.to(device=detection_bboxes.device)
                # One scale per detection, broadcast over the box coordinates
                # (assumes scale_batch holds one scalar per image -- TODO confirm).
                per_detection_scale = (
                    scale_batch[detection_batch_indices]
                    .unsqueeze(dim=-1)
                    .expand_as(detection_bboxes)
                )
                # Undo the resize applied to the input images.
                detection_bboxes = detection_bboxes / per_detection_scale

                kept_indices = torch.nonzero(detection_probs > self._MIN_DETECTION_PROB).view(-1)
                detection_bboxes = detection_bboxes[kept_indices]
                detection_classes = detection_classes[kept_indices]
                detection_probs = detection_probs[kept_indices]
                detection_batch_indices = detection_batch_indices[kept_indices]

                all_detection_bboxes.extend(detection_bboxes.tolist())
                all_detection_classes.extend(detection_classes.tolist())
                all_detection_probs.extend(detection_probs.tolist())
                # Convert once to plain ints instead of indexing with 0-d
                # tensors one element at a time.
                all_image_ids.extend(
                    image_id_batch[index] for index in detection_batch_indices.tolist()
                )

        mean_ap, detail = self._dataset.evaluate(
            self._path_to_results_dir,
            all_image_ids,
            all_detection_bboxes,
            all_detection_classes,
            all_detection_probs,
        )
        return mean_ap, detail