-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathhelper_metrics.py
56 lines (41 loc) · 2 KB
/
helper_metrics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import torch
import segmentation_models_pytorch as smp
def _ratio_or_one(numerator, denominator):
    """Return ``numerator / denominator``, or 1 wherever ``denominator`` is 0.

    Used for the Jaccard and Dice scores: a zero denominator there means both
    the prediction and the ground-truth mask are empty, which is perfect
    agreement. Returning 1 avoids a 0/0 = NaN that would otherwise make the
    mean over all batches NaN.
    """
    ratio = numerator / denominator
    return torch.where(denominator > 0, ratio, torch.ones_like(ratio))


def check_metrics(dataloader, model, device="cuda", threshold=0.5):
    """Evaluate a binary segmentation model; print and return averaged metrics.

    For every ``(image, mask)`` batch from ``dataloader`` the model output is
    binarised with ``threshold`` and compared to the mask. Each metric is
    computed once per batch and then averaged over batches (unweighted, so a
    smaller final batch counts as much as a full one).

    Args:
        dataloader: Iterable yielding ``(image, mask)`` pairs.
        model: Callable mapping an image batch to a prediction batch. If it is
            a ``torch.nn.Module`` it is put in eval mode for the evaluation
            and returned to its previous train/eval mode afterwards.
            NOTE(review): ``threshold`` is applied directly to the model
            output, which assumes the output is already a probability (e.g.
            the model ends in a sigmoid) — confirm against the model.
        device: Device that images and masks are moved to.
        threshold: Cut-off used to binarise the model output; defaults to the
            previously hard-coded 0.5.

    Returns:
        Tuple ``(sensitivity, specificity, pix_acc, jaccard, dice)`` of Python
        floats, each the mean of the per-batch scores.

    Raises:
        RuntimeError: if ``dataloader`` yields no batches.
    """
    sensitivity_list = []
    specificity_list = []
    accuracy_list = []
    jaccard_list = []
    dice_list = []

    # Evaluate with dropout disabled and batch-norm using running statistics,
    # then restore whatever mode the caller had the model in.
    is_module = isinstance(model, torch.nn.Module)
    was_training = is_module and model.training
    if is_module:
        model.eval()
    try:
        with torch.no_grad():
            for image, mask in dataloader:
                image = image.to(device).type(torch.float32)
                mask = mask.to(device)
                pred = (model(image) > threshold).int()

                # pred is already 0/1 here; the get_stats call is kept as in
                # the original. NOTE(review): smp's get_stats expects an
                # integer-typed mask — assumes the dataloader yields one.
                tp, fp, fn, tn = smp.metrics.get_stats(
                    pred, mask, threshold=0.5, mode="binary"
                )
                sensitivity_list.append(
                    smp.metrics.functional.sensitivity(tp, fp, fn, tn, reduction="micro")
                )
                specificity_list.append(
                    smp.metrics.functional.specificity(tp, fp, fn, tn, reduction="micro")
                )
                accuracy_list.append(
                    smp.metrics.accuracy(tp, fp, fn, tn, reduction="micro")
                )

                intersection = torch.logical_and(pred, mask).sum()
                union = torch.logical_or(pred, mask).sum()
                jaccard_list.append(_ratio_or_one(intersection, union))
                dice_list.append(
                    _ratio_or_one(2.0 * intersection, pred.sum() + mask.sum())
                )
    finally:
        if is_module:
            model.train(was_training)

    if not dice_list:
        raise RuntimeError("check_metrics: dataloader yielded no batches")

    sensitivity = torch.mean(torch.stack(sensitivity_list)).item()
    specificity = torch.mean(torch.stack(specificity_list)).item()
    print(f"Sensitivity: {sensitivity:.6f}")
    print(f"Specificity: {specificity:.6f}")
    pix_acc = torch.mean(torch.stack(accuracy_list)).item()
    print(f"Pixel Accuracy: {pix_acc:.6f}")
    jaccard = torch.mean(torch.stack(jaccard_list)).item()
    print(f"Jaccard Score: {jaccard:.6f}")
    dice = torch.mean(torch.stack(dice_list)).item()
    print(f"Dice Score: {dice:.6f}")
    return sensitivity, specificity, pix_acc, jaccard, dice