# inference.py
import argparse
import os

import albumentations as A
import cv2
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm

from utils.dataset import IND2CLASS, XRayInferenceDataset
from utils.method import encode_mask_to_rle


def apply_cca(mask, min_size=500, max_components=3):
"""
Apply more aggressive Connected Component Analysis
Args:
mask: Binary mask
min_size: Minimum component size to keep
max_components: Maximum number of components to keep (keep largest ones)
Returns:
Cleaned mask
"""
    # Label connected components (8-connectivity) and collect their stats.
    num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(mask.astype(np.uint8), connectivity=8)

    cleaned_mask = np.zeros_like(mask)

    # Collect all component sizes, excluding the background (label 0),
    # and sort them largest first.
    sizes = [(label, stats[label, cv2.CC_STAT_AREA]) for label in range(1, num_labels)]
    sizes.sort(key=lambda x: x[1], reverse=True)

    # Keep only components that meet the size threshold, up to max_components.
    count = 0
    for label, size in sizes:
        if size >= min_size and count < max_components:
            cleaned_mask[labels == label] = 1
            count += 1

    return cleaned_mask
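
# A quick sanity check of apply_cca (illustrative only, not part of the
# pipeline): a 16-pixel square plus a 1-pixel speckle; with min_size=4 and
# max_components=1, only the square survives.
#
#     toy = np.zeros((8, 8), dtype=np.uint8)
#     toy[1:5, 1:5] = 1          # 16-pixel component
#     toy[7, 7] = 1              # 1-pixel speckle
#     assert apply_cca(toy, min_size=4, max_components=1).sum() == 16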

# The CCA parameters are min_component_size and max_components.
# If an outlier (a falsely predicted region) exceeds 2000 pixels, CCA may not
# remove it at all, so visualize the inference results often to check!
def test(model, data_loader, thr=0.5, min_component_size=2000, max_components=1):
    model = model.cuda()
    model.eval()

    rles = []
    filename_and_class = []
    with torch.no_grad():
        for images, image_names in tqdm(data_loader, total=len(data_loader)):
            images = images.cuda()
            outputs = model(images)
            # torchvision segmentation models return a dict with an 'out' key;
            # other models return the logits tensor directly.
            if isinstance(outputs, dict):
                outputs = outputs['out']
            # Upsample predictions back to the original 2048x2048 resolution.
            outputs = F.interpolate(outputs, size=(2048, 2048), mode="bilinear")
            outputs = torch.sigmoid(outputs)
            outputs = (outputs > thr).detach().cpu().numpy()

            for output, image_name in zip(outputs, image_names):
                for c, segm in enumerate(output):
                    # Apply CCA to suppress spurious components per class.
                    cleaned_segm = apply_cca(segm, min_size=min_component_size, max_components=max_components)
                    rle = encode_mask_to_rle(cleaned_segm)
                    rles.append(rle)
                    filename_and_class.append(f"{IND2CLASS[c]}_{image_name}")

    return rles, filename_and_class
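
# For reference, `encode_mask_to_rle` (imported from utils.method) is assumed
# here to follow the standard Kaggle run-length-encoding convention; a minimal
# sketch of that convention, for readers without access to utils.method:
#
#     def encode_mask_to_rle_sketch(mask):
#         pixels = mask.flatten()
#         pixels = np.concatenate([[0], pixels, [0]])
#         runs = np.where(pixels[1:] != pixels[:-1])[0] + 1
#         runs[1::2] -= runs[::2]      # convert change points into run lengths
#         return ' '.join(str(x) for x in runs)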

def parse_args():
    parser = argparse.ArgumentParser(description='X-ray image segmentation inference')
    parser.add_argument('--image_root', type=str, default='./data/test/DCM',
                        help='Directory containing the test images')
    parser.add_argument('--model_path', type=str, default='./checkpoints/fcn_resnet50.pt',
                        help='Path to the trained model file')
    parser.add_argument('--batch_size', type=int, default=2,
                        help='Batch size')
    parser.add_argument('--threshold', type=float, default=0.5,
                        help='Segmentation threshold')
    parser.add_argument('--output_path', type=str, default='output.csv',
                        help='Path of the CSV file to save the results to')
    parser.add_argument('--img_size', type=int, default=512,
                        help='Input image size')
    parser.add_argument('--min_component_size', type=int, default=2000,
                        help='Minimum size for connected components')
    parser.add_argument('--max_components', type=int, default=1,
                        help='Maximum number of components to keep')
    return parser.parse_args()
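
# Example invocation (paths are illustrative):
#
#     python inference.py \
#         --model_path ./checkpoints/fcn_resnet50.pt \
#         --image_root ./data/test/DCM \
#         --output_path output.csv \
#         --min_component_size 2000 --max_components 1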

def main():
    args = parse_args()

    # Load the full serialized model (not just a state_dict). On newer
    # PyTorch (2.6+) this may require torch.load(..., weights_only=False).
    model = torch.load(args.model_path)

    # Set up the dataset and data loader.
    tf = A.Compose([
        A.Resize(args.img_size, args.img_size),
    ])
    test_dataset = XRayInferenceDataset(args.image_root, transforms=tf)
    test_loader = DataLoader(
        dataset=test_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=8,
        drop_last=False
    )

    # Run inference.
    rles, filename_and_class = test(
        model,
        test_loader,
        thr=args.threshold,
        min_component_size=args.min_component_size,
        max_components=args.max_components
    )

    # Build the submission file. Split on the first underscore only, so image
    # names that contain underscores are not broken apart (this assumes class
    # names themselves contain no underscore).
    classes, filename = zip(*[x.split("_", 1) for x in filename_and_class])
    image_name = [os.path.basename(f) for f in filename]
    df = pd.DataFrame({
        "image_name": image_name,
        "class": classes,
        "rle": rles,
    })
    df.to_csv(args.output_path, index=False)


if __name__ == '__main__':
    main()
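
# As the comment above `test` advises, inference results should be visualized
# often to confirm CCA is behaving as intended. A minimal overlay sketch
# (matplotlib assumed to be available; `show_mask` is a hypothetical helper,
# not part of this repository):
#
#     import matplotlib.pyplot as plt
#
#     def show_mask(image, mask, title=""):
#         plt.figure(figsize=(8, 8))
#         plt.imshow(image, cmap="gray")
#         plt.imshow(mask, alpha=0.4)   # semi-transparent prediction overlay
#         plt.title(title)
#         plt.show()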