feat: SMP Baseline code refactoring #23
jhuni17 committed Nov 18, 2024
1 parent fabc98e commit a047ab9
Showing 8 changed files with 478 additions and 0 deletions.
31 changes: 31 additions & 0 deletions smp/config/config.py
@@ -0,0 +1,31 @@
from pathlib import Path

class Config:
    # Data
    TRAIN_IMAGE_ROOT = "train/DCM"
    TRAIN_LABEL_ROOT = "train/outputs_json"
    TEST_IMAGE_ROOT = "test/DCM"

    # Model
    BATCH_SIZE = 8
    LEARNING_RATE = 1e-4
    NUM_EPOCHS = 5
    VAL_EVERY = 5
    RANDOM_SEED = 21

    # Paths
    SAVED_DIR = Path("checkpoints")
    SAVED_DIR.mkdir(exist_ok=True)

    # Classes
    CLASSES = [
        'finger-1', 'finger-2', 'finger-3', 'finger-4', 'finger-5',
        'finger-6', 'finger-7', 'finger-8', 'finger-9', 'finger-10',
        'finger-11', 'finger-12', 'finger-13', 'finger-14', 'finger-15',
        'finger-16', 'finger-17', 'finger-18', 'finger-19', 'Trapezium',
        'Trapezoid', 'Capitate', 'Hamate', 'Scaphoid', 'Lunate',
        'Triquetrum', 'Pisiform', 'Radius', 'Ulna',
    ]

    CLASS2IND = {v: i for i, v in enumerate(CLASSES)}
    IND2CLASS = {v: k for k, v in CLASS2IND.items()}
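As a quick sanity check on the two lookup tables: CLASS2IND maps a bone name to its output-channel index and IND2CLASS maps the index back. A minimal usage sketch, assuming the module is importable from the smp/ directory:

# Sketch: the lookup tables are inverses of each other (paths assumed relative to smp/).
from config.config import Config

assert Config.CLASS2IND["finger-1"] == 0    # first class occupies channel 0
assert Config.IND2CLASS[28] == "Ulna"       # last of the 29 classes
assert len(Config.CLASSES) == 29            # one mask channel per bone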
161 changes: 161 additions & 0 deletions smp/dataset/dataset.py
@@ -0,0 +1,161 @@
import os
import cv2
import json
import numpy as np
import torch
from torch.utils.data import Dataset
from sklearn.model_selection import GroupKFold

from config.config import Config

class XRayDataset(Dataset):
    def __init__(self, image_root, label_root=None, is_train=True, transforms=None):
        self.is_train = is_train
        self.transforms = transforms

        # Get all PNG files
        self.image_root = image_root
        self.label_root = label_root

        self.pngs = self._get_pngs()
        if is_train:
            self.jsons = self._get_jsons()

            # Verify matching between pngs and jsons
            jsons_fn_prefix = {os.path.splitext(fname)[0] for fname in self.jsons}
            pngs_fn_prefix = {os.path.splitext(fname)[0] for fname in self.pngs}

            # Check if all files match
            assert len(jsons_fn_prefix - pngs_fn_prefix) == 0, "Some JSON files don't have matching PNGs"
            assert len(pngs_fn_prefix - jsons_fn_prefix) == 0, "Some PNG files don't have matching JSONs"

            self.filenames, self.labelnames = self._split_dataset()
        else:
            self.filenames = sorted(self.pngs)

    def _get_pngs(self):
        return sorted([
            os.path.relpath(os.path.join(root, fname), start=self.image_root)
            for root, _dirs, files in os.walk(self.image_root)
            for fname in files
            if os.path.splitext(fname)[1].lower() == ".png"
        ])

    def _get_jsons(self):
        return sorted([
            os.path.relpath(os.path.join(root, fname), start=self.label_root)
            for root, _dirs, files in os.walk(self.label_root)
            for fname in files
            if os.path.splitext(fname)[1].lower() == ".json"
        ])

    def _split_dataset(self):
        _filenames = np.array(self.pngs)
        _labelnames = np.array(self.jsons)

        # Split train-valid using GroupKFold
        groups = [os.path.dirname(fname) for fname in _filenames]

        # dummy label
        ys = [0 for _ in _filenames]

        gkf = GroupKFold(n_splits=5)

        filenames = []
        labelnames = []

        for i, (x, y) in enumerate(gkf.split(_filenames, ys, groups)):
            if self.is_train:
                if i == 0:  # Use fold 0 as validation
                    continue

                filenames += list(_filenames[y])
                labelnames += list(_labelnames[y])
            else:
                filenames = list(_filenames[y])
                labelnames = list(_labelnames[y])
                break

        return filenames, labelnames

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, item):
        image_name = self.filenames[item]
        image_path = os.path.join(self.image_root, image_name)

        image = cv2.imread(image_path)
        image = image / 255.

        if self.is_train:
            label_name = self.labelnames[item]
            label_path = os.path.join(self.label_root, label_name)

            # Create label with shape (H, W, NC)
            label_shape = tuple(image.shape[:2]) + (29,)  # 29 classes
            label = np.zeros(label_shape, dtype=np.uint8)

            with open(label_path, "r") as f:
                annotations = json.load(f)
            annotations = annotations["annotations"]

            # Process each class: rasterize the polygon into its channel
            for ann in annotations:
                c = ann["label"]
                class_ind = Config.CLASS2IND[c]  # class name -> channel index
                points = np.array(ann["points"])

                class_label = np.zeros(image.shape[:2], dtype=np.uint8)
                cv2.fillPoly(class_label, [points], 1)
                label[..., class_ind] = class_label

            if self.transforms is not None:
                inputs = {"image": image, "mask": label}
                result = self.transforms(**inputs)
                image = result["image"]
                label = result["mask"]

            # Convert to tensor format (HWC -> CHW)
            image = image.transpose(2, 0, 1)
            label = label.transpose(2, 0, 1)

            return torch.from_numpy(image).float(), torch.from_numpy(label).float()
        else:
            if self.transforms is not None:
                inputs = {"image": image}
                result = self.transforms(**inputs)
                image = result["image"]

            image = image.transpose(2, 0, 1)
            return torch.from_numpy(image).float(), image_name

class XRayInferenceDataset(Dataset):
    def __init__(self, image_root, transforms=None):
        self.image_root = image_root
        self.transforms = transforms
        self.filenames = self._get_pngs()

    def _get_pngs(self):
        return sorted([
            os.path.relpath(os.path.join(root, fname), start=self.image_root)
            for root, _dirs, files in os.walk(self.image_root)
            for fname in files
            if os.path.splitext(fname)[1].lower() == ".png"
        ])

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, item):
        image_name = self.filenames[item]
        image_path = os.path.join(self.image_root, image_name)

        image = cv2.imread(image_path)
        image = image / 255.

        if self.transforms is not None:
            inputs = {"image": image}
            result = self.transforms(**inputs)
            image = result["image"]

        image = image.transpose(2, 0, 1)
        return torch.from_numpy(image).float(), image_name
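The GroupKFold split above groups images by their parent directory (one folder per patient), so fold 0 is held out and the remaining four folds form the training set without patient leakage. A minimal wiring sketch for the training side, assuming the Config and Transforms modules added in this commit and execution from the smp/ directory (the commit's own train.py is not rendered here):

# Sketch: build a training DataLoader from XRayDataset.
from torch.utils.data import DataLoader
from config.config import Config
from dataset.dataset import XRayDataset
from dataset.transforms import Transforms

train_dataset = XRayDataset(
    image_root=Config.TRAIN_IMAGE_ROOT,
    label_root=Config.TRAIN_LABEL_ROOT,
    is_train=True,
    transforms=Transforms.get_train_transform(),
)
train_loader = DataLoader(train_dataset, batch_size=Config.BATCH_SIZE, shuffle=True, num_workers=2)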
21 changes: 21 additions & 0 deletions smp/dataset/transforms.py
@@ -0,0 +1,21 @@
import albumentations as A

class Transforms:
    @staticmethod
    def get_train_transform():
        return A.Compose([
            A.Resize(512, 512),
            # TODO: Add more augmentations later
        ])

    @staticmethod
    def get_valid_transform():
        return A.Compose([
            A.Resize(512, 512),
        ])

    @staticmethod
    def get_test_transform():
        return A.Compose([
            A.Resize(512, 512),
        ])
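The TODO in get_train_transform leaves room for augmentation. A sketch of what filling it in could look like with standard albumentations ops; the specific choices here are illustrative, not part of this commit:

# Illustrative only: one possible extension of the train transform.
import albumentations as A

train_transform = A.Compose([
    A.Resize(512, 512),
    A.HorizontalFlip(p=0.5),             # left/right hands appear in both orientations
    A.RandomBrightnessContrast(p=0.3),   # X-ray exposure varies between scans
])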
75 changes: 75 additions & 0 deletions smp/inference.py
@@ -0,0 +1,75 @@
import os
import cv2
import torch
import torch.nn.functional as F
import pandas as pd
from tqdm.auto import tqdm
from torch.utils.data import DataLoader

from config.config import Config
from dataset.dataset import XRayDataset
from utils.rle import encode_mask_to_rle
from dataset.transforms import Transforms  # import the Transforms class

def test(model, data_loader, thr=0.5):
    model = model.cuda()
    model.eval()

    rles = []
    filename_and_class = []

    with torch.no_grad():
        for step, (images, image_names) in tqdm(enumerate(data_loader), total=len(data_loader)):
            images = images.cuda()
            outputs = model(images)  # smp models return the logits tensor directly

            # Resize to original size
            outputs = F.interpolate(outputs, size=(2048, 2048), mode="bilinear")
            outputs = torch.sigmoid(outputs)
            outputs = (outputs > thr).detach().cpu().numpy()

            for output, image_name in zip(outputs, image_names):
                for c, segm in enumerate(output):
                    rle = encode_mask_to_rle(segm)
                    rles.append(rle)
                    filename_and_class.append(f"{Config.IND2CLASS[c]}_{image_name}")

    return rles, filename_and_class

def main():
    # Prepare the test dataset
    test_dataset = XRayDataset(
        image_root=Config.TEST_IMAGE_ROOT,
        is_train=False,
        transforms=Transforms.get_test_transform()  # use the Transforms class
    )

    test_loader = DataLoader(
        dataset=test_dataset,
        batch_size=2,
        shuffle=False,
        num_workers=2,
        drop_last=False
    )

    # Load the trained model
    model = torch.load(os.path.join(Config.SAVED_DIR, "best_model.pt"))

    # Run inference
    rles, filename_and_class = test(model, test_loader)

    # Convert results to a DataFrame
    classes, filename = zip(*[x.split("_") for x in filename_and_class])
    image_name = [os.path.basename(f) for f in filename]

    df = pd.DataFrame({
        "image_name": image_name,
        "class": classes,
        "rle": rles,
    })

    # Save the submission CSV
    df.to_csv("submission.csv", index=False)

if __name__ == "__main__":
    main()
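encode_mask_to_rle comes from utils/rle.py, which is among the changed files but not rendered above. For reference, the run-length scheme commonly used for this kind of submission looks roughly like the following; this is a sketch of the usual encoding, not necessarily the exact contents of utils/rle.py:

# Sketch: standard run-length encoding of a binary mask (row-major, 1-indexed start positions).
import numpy as np

def encode_mask_to_rle(mask):
    pixels = mask.flatten()
    pixels = np.concatenate([[0], pixels, [0]])         # pad so mask edges become transitions
    runs = np.where(pixels[1:] != pixels[:-1])[0] + 1   # positions where the value changes
    runs[1::2] -= runs[::2]                             # turn end positions into run lengths
    return " ".join(str(x) for x in runs)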
9 changes: 9 additions & 0 deletions smp/models/model.py
@@ -0,0 +1,9 @@
import segmentation_models_pytorch as smp

def get_model(num_classes=29):
    return smp.Unet(
        encoder_name="efficientnet-b0",
        encoder_weights="imagenet",
        in_channels=3,
        classes=num_classes,
    )
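get_model returns a plain smp.Unet, so a forward pass yields one logit map per class at the input resolution. A quick shape check under the 512x512 transforms defined above (assumes execution from smp/; constructing the model downloads the ImageNet encoder weights):

# Sketch: confirm the U-Net produces a (N, 29, 512, 512) logit tensor.
import torch
from models.model import get_model

model = get_model(num_classes=29)
with torch.no_grad():
    out = model(torch.randn(1, 3, 512, 512))  # dummy batch at the training resolution
print(out.shape)  # expected: torch.Size([1, 29, 512, 512])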