-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathinference.py
165 lines (133 loc) · 5.44 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
import os
import cv2
import argparse
import yaml
import torch
import torch.nn.functional as F
import numpy as np
import pandas as pd
from tqdm import tqdm
import albumentations as A
from torch.utils.data import DataLoader
from src.dataset import XRayInferenceDataset
from utils.utils_for_visualizer import encode_mask_to_rle, decode_rle_to_mask
def parse_args(argv=None):
    """Parse command-line arguments and resolve default file locations.

    Args:
        argv: Argument list to parse; ``None`` lets argparse read
            ``sys.argv[1:]`` (the previous, and default, behaviour).

    Returns:
        argparse.Namespace with ``config``, ``model_path``, ``output_path``
        and ``threshold``. Missing paths default to the config file's stem
        (``<stem>.pt`` / ``<stem>.csv``). Bare file names are placed under
        ``checkpoints/<stem>/`` and ``results/`` respectively; paths that
        already contain a directory component are kept as given. The
        directory that will receive ``output_path`` is created if missing.
    """
    parser = argparse.ArgumentParser(description='Inference segmentation model')
    parser.add_argument('-c', '--config', type=str, default='smp_unetplusplus_efficientb0.yaml',
                        help='path to config file')
    parser.add_argument('-m', '--model_path', type=str, default=None,
                        help='path to model checkpoint')
    parser.add_argument('-o', '--output_path', type=str, default=None,
                        help='path to save prediction results')
    parser.add_argument('--threshold', type=float, default=0.5,
                        help='threshold for binary prediction')
    args = parser.parse_args(argv)

    config_stem = os.path.splitext(args.config)[0]

    # If model_path / output_path are not specified, derive them from the config name.
    if args.model_path is None:
        args.model_path = config_stem + '.pt'
    if args.output_path is None:
        args.output_path = config_stem + '.csv'

    # Bare file names go under the conventional per-config directories.
    if not os.path.dirname(args.model_path):
        args.model_path = os.path.join('checkpoints', config_stem, args.model_path)
    if not os.path.dirname(args.output_path):
        args.output_path = os.path.join('results', args.output_path)

    # Create whichever directory the CSV will actually be written to, so a
    # custom output directory works as well as the default 'results'.
    # After the join above, dirname(output_path) is never empty.
    os.makedirs(os.path.dirname(args.output_path), exist_ok=True)
    return args
def load_config(config_name, config_dir='configs'):
    """Load a YAML config file and return its top-level mapping.

    Args:
        config_name: File name of the config, relative to ``config_dir``.
        config_dir: Directory holding config files (default ``'configs'``).

    Returns:
        The parsed config as a dict.

    Raises:
        SystemExit: with exit status 1 and the message printed to stderr, if
            the file is missing, is not valid YAML, or does not contain a
            top-level mapping.
    """
    config_path = os.path.join(config_dir, config_name)
    # isfile (not exists) so a directory with that name is also rejected here
    # instead of failing later inside open().
    if not os.path.isfile(config_path):
        raise SystemExit(f'Config file not found: {config_path}')
    with open(config_path, 'r', encoding='utf-8') as f:
        try:
            config = yaml.safe_load(f)
        except yaml.YAMLError as e:
            raise SystemExit(f'Error loading config file: {e}') from e
    # safe_load returns None for an empty file; callers index the result like
    # a dict (cfg['DATASET'], cfg['CLASSES']), so reject anything else now.
    if not isinstance(config, dict):
        raise SystemExit(f'Config file must contain a mapping: {config_path}')
    return config
class Inferencer:
    """Run a saved segmentation model over a data loader and RLE-encode its masks.

    Args:
        cfg: Parsed config dict; ``cfg['DATASET']['IMAGE_SIZE']`` and
            ``cfg['CLASSES']`` are read.
        model_path: Path to a checkpoint holding a whole pickled model (the
            loaded object has ``.to()``/``.eval()`` called on it directly, so
            a bare state dict will not work).
        threshold: Probability cut-off applied after the sigmoid.
        output_size: Size passed to ``F.interpolate`` when resizing logits
            before thresholding; defaults to the previously hard-coded
            ``(2048, 2048)``.
    """

    def __init__(self, cfg, model_path, threshold=0.5, output_size=(2048, 2048)):
        self.cfg = cfg
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.threshold = threshold
        self.output_size = output_size

        # map_location lets a checkpoint saved on GPU load on a CPU-only host.
        # weights_only=False is needed to unpickle a full nn.Module on
        # torch>=2.6, where the default flipped to True (the keyword itself
        # requires torch>=1.13).
        # SECURITY: this executes arbitrary pickle code, so only load trusted
        # checkpoints.
        self.model = torch.load(model_path, map_location=self.device, weights_only=False)
        self.model = self.model.to(self.device)
        self.model.eval()

        # Resize transform at the configured input size. predict() does not
        # use it (main() builds the dataset's own transform); it is kept in
        # case other code reads this attribute.
        self.transform = A.Compose([
            A.Resize(cfg['DATASET']['IMAGE_SIZE'], cfg['DATASET']['IMAGE_SIZE']),
        ])

    def postprocess(self, outputs):
        """Turn raw model output into boolean masks as a numpy array.

        Args:
            outputs: Logits tensor of shape (N, C, H, W) (bilinear
                ``F.interpolate`` requires 4-D input), or a dict whose
                ``'out'`` entry is such a tensor.

        Returns:
            numpy bool array of shape (N, C, *output_size): True where
            ``sigmoid(resized logits) > threshold``.
        """
        if isinstance(outputs, dict):
            # Dict-returning models are expected to put the main prediction
            # under 'out'; adjust the key if your model differs.
            outputs = outputs['out']
        # Resize logits to the target resolution before the sigmoid/threshold.
        outputs = F.interpolate(outputs, size=self.output_size, mode="bilinear")
        probs = torch.sigmoid(outputs)
        return (probs > self.threshold).cpu().numpy()

    def predict(self, data_loader):
        """Run inference over ``data_loader`` and RLE-encode every class mask.

        Args:
            data_loader: Iterable of ``(images, image_names)`` batches that
                supports ``len()``.

        Returns:
            Tuple ``(rles, filename_and_class)`` of equal-length lists. Entry i
            of ``filename_and_class`` is ``f"{class_name}_{image_name}"`` for
            the mask whose RLE is entry i of ``rles``.
        """
        rles = []
        filename_and_class = []
        with torch.no_grad():
            for images, image_names in tqdm(data_loader, total=len(data_loader)):
                outputs = self.model(images.to(self.device))
                masks = self.postprocess(outputs)
                # One RLE per (image, class) pair; axis 1 of masks is the class index.
                for mask, image_name in zip(masks, image_names):
                    for class_idx, class_mask in enumerate(mask):
                        rles.append(encode_mask_to_rle(class_mask))
                        filename_and_class.append(f"{self.cfg['CLASSES'][class_idx]}_{image_name}")
        return rles, filename_and_class
def build_submission(rles, filename_and_class):
    """Assemble the submission table from ``Inferencer.predict`` output.

    Args:
        rles: RLE-encoded masks, one per (image, class) pair.
        filename_and_class: Matching ``"{class_name}_{image_path}"`` labels.

    Returns:
        pandas DataFrame with columns ``image_name`` (basename of the image
        path), ``class`` and ``rle``, in input order. Empty inputs yield an
        empty frame with the same columns.
    """
    classes = []
    image_names = []
    for label in filename_and_class:
        # Split only at the first '_' so image paths that contain underscores
        # stay intact. Assumes class names themselves contain no '_'.
        class_name, filename = label.split("_", 1)
        classes.append(class_name)
        image_names.append(os.path.basename(filename))
    return pd.DataFrame({
        "image_name": image_names,
        "class": classes,
        "rle": rles,
    })


def main(args=None):
    """Run inference end to end and write the submission CSV.

    Args:
        args: Namespace with ``config``, ``model_path``, ``output_path`` and
            ``threshold`` attributes; parsed from the command line when None.
    """
    if args is None:
        args = parse_args()
    cfg = load_config(args.config)

    # Dataset and loader at the configured input resolution, in file order.
    image_size = cfg['DATASET']['IMAGE_SIZE']
    test_dataset = XRayInferenceDataset(
        cfg=cfg,
        transforms=A.Compose([
            A.Resize(image_size, image_size),
        ]),
    )
    test_loader = DataLoader(
        dataset=test_dataset,
        batch_size=cfg['DATASET']['BATCH_SIZE'],
        shuffle=False,
        num_workers=cfg['DATASET']['NUM_WORKERS'],
        drop_last=False,
    )

    inferencer = Inferencer(
        cfg=cfg,
        model_path=args.model_path,
        threshold=args.threshold,
    )

    print("Starting inference...")
    rles, filename_and_class = inferencer.predict(test_loader)

    df = build_submission(rles, filename_and_class)

    # args may come from a caller rather than parse_args(), so ensure the
    # destination directory exists before writing.
    output_dir = os.path.dirname(args.output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    df.to_csv(args.output_path, index=False)
    print(f"Results saved to {args.output_path}")


if __name__ == '__main__':
    main()