From a047ab967d69fa0d09ed71fba1393bc55c0ce9e0 Mon Sep 17 00:00:00 2001
From: jhuni17 
Date: Mon, 18 Nov 2024 11:22:46 +0900
Subject: [PATCH] feat: SMP Baseline code refactoring #23

---
 smp/config/config.py | 31 ++++
 smp/dataset/dataset.py | 161 ++++++++++++++++++++++++++++++++++++++
 smp/dataset/transforms.py | 21 +++++
 smp/inference.py | 75 ++++++++++++++++++
 smp/models/model.py | 9 +++
 smp/train.py | 150 +++++++++++++++++++++++++++++++++++
 smp/utils/metrics.py | 11 +++
 smp/utils/rle.py | 20 +++++
 8 files changed, 478 insertions(+)
 create mode 100644 smp/config/config.py
 create mode 100644 smp/dataset/dataset.py
 create mode 100644 smp/dataset/transforms.py
 create mode 100644 smp/inference.py
 create mode 100644 smp/models/model.py
 create mode 100644 smp/train.py
 create mode 100644 smp/utils/metrics.py
 create mode 100644 smp/utils/rle.py

diff --git a/smp/config/config.py b/smp/config/config.py
new file mode 100644
index 0000000..ada6c4b
--- /dev/null
+++ b/smp/config/config.py
@@ -0,0 +1,31 @@
+from pathlib import Path
+
+class Config:
+    # Data
+    TRAIN_IMAGE_ROOT = "train/DCM"
+    TRAIN_LABEL_ROOT = "train/outputs_json"
+    TEST_IMAGE_ROOT = "test/DCM"
+
+    # Model
+    BATCH_SIZE = 8
+    LEARNING_RATE = 1e-4
+    NUM_EPOCHS = 5
+    VAL_EVERY = 5
+    RANDOM_SEED = 21
+
+    # Paths
+    SAVED_DIR = Path("checkpoints")
+    SAVED_DIR.mkdir(exist_ok=True)
+
+    # Classes
+    CLASSES = [
+        'finger-1', 'finger-2', 'finger-3', 'finger-4', 'finger-5',
+        'finger-6', 'finger-7', 'finger-8', 'finger-9', 'finger-10',
+        'finger-11', 'finger-12', 'finger-13', 'finger-14', 'finger-15',
+        'finger-16', 'finger-17', 'finger-18', 'finger-19', 'Trapezium',
+        'Trapezoid', 'Capitate', 'Hamate', 'Scaphoid', 'Lunate',
+        'Triquetrum', 'Pisiform', 'Radius', 'Ulna',
+    ]
+
+    CLASS2IND = {v: i for i, v in enumerate(CLASSES)}
+    IND2CLASS = {v: k for k, v in CLASS2IND.items()}
\ No newline at end of file
diff --git a/smp/dataset/dataset.py b/smp/dataset/dataset.py
new file mode 100644
index 0000000..bfcbc6f
--- /dev/null
+++ b/smp/dataset/dataset.py
@@ -0,0 +1,161 @@
+import os
+import cv2
+import json
+import numpy as np
+import torch
+from torch.utils.data import Dataset
+from sklearn.model_selection import GroupKFold
+
+from config.config import Config
+
+class XRayDataset(Dataset):
+    """Train/validation dataset: images plus polygon labels, split by GroupKFold."""
+
+    def __init__(self, image_root, label_root, is_train=True, transforms=None):
+        self.is_train = is_train
+        self.transforms = transforms
+        self.image_root = image_root
+        self.label_root = label_root
+
+        # Collect all PNG images and JSON labels
+        self.pngs = self._get_pngs()
+        self.jsons = self._get_jsons()
+
+        # Verify one-to-one matching between PNGs and JSONs
+        jsons_fn_prefix = {os.path.splitext(fname)[0] for fname in self.jsons}
+        pngs_fn_prefix = {os.path.splitext(fname)[0] for fname in self.pngs}
+
+        assert len(jsons_fn_prefix - pngs_fn_prefix) == 0, "Some JSON files don't have matching PNGs"
+        assert len(pngs_fn_prefix - jsons_fn_prefix) == 0, "Some PNG files don't have matching JSONs"
+
+        self.filenames, self.labelnames = self._split_dataset()
+
+    def _get_pngs(self):
+        return sorted([
+            os.path.relpath(os.path.join(root, fname), start=self.image_root)
+            for root, _dirs, files in os.walk(self.image_root)
+            for fname in files
+            if os.path.splitext(fname)[1].lower() == ".png"
+        ])
+
+    def _get_jsons(self):
+        return sorted([
+            os.path.relpath(os.path.join(root, fname), start=self.label_root)
+            for root, _dirs, files in os.walk(self.label_root)
+            for fname in files
+            if os.path.splitext(fname)[1].lower() == ".json"
+        ])
+
+    def _split_dataset(self):
+        _filenames = np.array(self.pngs)
+        _labelnames = np.array(self.jsons)
+
+        # Split train/valid with GroupKFold, grouping by the parent directory
+        # (one folder per patient) so a patient never straddles both splits
+        groups = [os.path.dirname(fname) for fname in _filenames]
+
+        # Dummy labels; GroupKFold only needs the groups
+        ys = [0 for _ in _filenames]
+
+        gkf = GroupKFold(n_splits=5)
+
+        filenames = []
+        labelnames = []
+
+        for i, (_, valid_idx) in enumerate(gkf.split(_filenames, ys, groups)):
+            if self.is_train:
+                if i == 0:  # Hold out fold 0 as validation
+                    continue
+                filenames += list(_filenames[valid_idx])
+                labelnames += list(_labelnames[valid_idx])
+            else:
+                filenames = list(_filenames[valid_idx])
+                labelnames = list(_labelnames[valid_idx])
+                break
+
+        return filenames, labelnames
+
+    def __len__(self):
+        return len(self.filenames)
+
+    def __getitem__(self, item):
+        image_name = self.filenames[item]
+        image_path = os.path.join(self.image_root, image_name)
+
+        image = cv2.imread(image_path)
+        image = image / 255.
+
+        label_name = self.labelnames[item]
+        label_path = os.path.join(self.label_root, label_name)
+
+        # Build a multi-label mask with shape (H, W, NC)
+        label_shape = tuple(image.shape[:2]) + (len(Config.CLASSES),)
+        label = np.zeros(label_shape, dtype=np.uint8)
+
+        with open(label_path, "r") as f:
+            annotations = json.load(f)["annotations"]
+
+        # Rasterize each annotated polygon into its class channel
+        for ann in annotations:
+            class_ind = Config.CLASS2IND[ann["label"]]
+            points = np.array(ann["points"])
+
+            class_label = np.zeros(image.shape[:2], dtype=np.uint8)
+            cv2.fillPoly(class_label, [points], 1)
+            label[..., class_ind] = class_label
+
+        if self.transforms is not None:
+            inputs = {"image": image, "mask": label}
+            result = self.transforms(**inputs)
+            image = result["image"]
+            label = result["mask"]
+
+        # (H, W, C) -> (C, H, W) for PyTorch
+        image = image.transpose(2, 0, 1)
+        label = label.transpose(2, 0, 1)
+
+        return torch.from_numpy(image).float(), torch.from_numpy(label).float()
+
+class XRayInferenceDataset(Dataset):
+    """Test-time dataset: images only, no labels."""
+
+    def __init__(self, image_root, transforms=None):
+        self.image_root = image_root
+        self.transforms = transforms
+        self.filenames = self._get_pngs()
+
+    def _get_pngs(self):
+        return sorted([
+            os.path.relpath(os.path.join(root, fname), start=self.image_root)
+            for root, _dirs, files in os.walk(self.image_root)
+            for fname in files
+            if os.path.splitext(fname)[1].lower() == ".png"
+        ])
+
+    def __len__(self):
+        return len(self.filenames)
+
+    def __getitem__(self, item):
+        image_name = self.filenames[item]
+        image_path = os.path.join(self.image_root, image_name)
+
+        image = cv2.imread(image_path)
+        image = image / 255.
+
+        if self.transforms is not None:
+            inputs = {"image": image}
+            result = self.transforms(**inputs)
+            image = result["image"]
+
+        image = image.transpose(2, 0, 1)
+        return torch.from_numpy(image).float(), image_name
\ No newline at end of file
diff --git a/smp/dataset/transforms.py b/smp/dataset/transforms.py
new file mode 100644
index 0000000..88dd366
--- /dev/null
+++ b/smp/dataset/transforms.py
@@ -0,0 +1,21 @@
+import albumentations as A
+
+class Transforms:
+    @staticmethod
+    def get_train_transform():
+        return A.Compose([
+            A.Resize(512, 512),
+            # TODO: Add more augmentations later
+        ])
+
+    @staticmethod
+    def get_valid_transform():
+        return A.Compose([
+            A.Resize(512, 512),
+        ])
+
+    @staticmethod
+    def get_test_transform():
+        return A.Compose([
+            A.Resize(512, 512),
+        ])
\ No newline at end of file
diff --git a/smp/inference.py b/smp/inference.py
new file mode 100644
index 0000000..fe96a32
--- /dev/null
+++ b/smp/inference.py
@@ -0,0 +1,75 @@
+import os
+import torch
+import torch.nn.functional as F
+import pandas as pd
+from tqdm.auto import tqdm
+from torch.utils.data import DataLoader
+
+from config.config import Config
+from dataset.dataset import XRayInferenceDataset
+from dataset.transforms import Transforms
+from utils.rle import encode_mask_to_rle
+
+def test(model, data_loader, thr=0.5):
+    model = model.cuda()
+    model.eval()
+
+    rles = []
+    filename_and_class = []
+
+    with torch.no_grad():
+        for images, image_names in tqdm(data_loader, total=len(data_loader)):
+            images = images.cuda()
+            # smp models return the logits tensor directly (no 'out' key)
+            outputs = model(images)
+
+            # Resize back to the original 2048x2048 resolution before thresholding
+            outputs = F.interpolate(outputs, size=(2048, 2048), mode="bilinear")
+            outputs = torch.sigmoid(outputs)
+            outputs = (outputs > thr).detach().cpu().numpy()
+
+            for output, image_name in zip(outputs, image_names):
+                for c, segm in enumerate(output):
+                    rle = encode_mask_to_rle(segm)
+                    rles.append(rle)
+                    filename_and_class.append(f"{Config.IND2CLASS[c]}_{image_name}")
+
+    return rles, filename_and_class
+
+def main():
+    # Prepare the test dataset
+    test_dataset = XRayInferenceDataset(
+        image_root=Config.TEST_IMAGE_ROOT,
+        transforms=Transforms.get_test_transform()
+    )
+
+    test_loader = DataLoader(
+        dataset=test_dataset,
+        batch_size=2,
+        shuffle=False,
+        num_workers=2,
+        drop_last=False
+    )
+
+    # Load the trained model
+    model = torch.load(os.path.join(Config.SAVED_DIR, "best_model.pt"))
+
+    # Run inference
+    rles, filename_and_class = test(model, test_loader)
+
+    # Build the submission DataFrame; split on the first "_" only,
+    # since image paths may themselves contain underscores
+    classes, filename = zip(*[x.split("_", 1) for x in filename_and_class])
+    image_name = [os.path.basename(f) for f in filename]
+
+    df = pd.DataFrame({
+        "image_name": image_name,
+        "class": classes,
+        "rle": rles,
+    })
+
+    # Save the submission CSV
+    df.to_csv("submission.csv", index=False)
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/smp/models/model.py b/smp/models/model.py
new file mode 100644
index 0000000..c116ee1
--- /dev/null
+++ b/smp/models/model.py
@@ -0,0 +1,9 @@
+import segmentation_models_pytorch as smp
+
+def get_model(num_classes=29):
+    return smp.Unet(
+        encoder_name="efficientnet-b0",
+        encoder_weights="imagenet",
+        in_channels=3,
+        classes=num_classes,
+    )
\ No newline at end of file
diff --git a/smp/train.py b/smp/train.py
new file mode 100644
index 0000000..056733e
--- /dev/null
+++ b/smp/train.py
@@ -0,0 +1,150 @@
+import os
+import datetime
+import random
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from torch.utils.data import DataLoader
+from tqdm.auto import tqdm
+
+from config.config import Config
+from dataset.dataset import XRayDataset
+from dataset.transforms import Transforms
+from models.model import get_model
+from utils.metrics import dice_coef
+
+def set_seed(seed):
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+    torch.backends.cudnn.deterministic = True
+    torch.backends.cudnn.benchmark = False
+    np.random.seed(seed)
+    random.seed(seed)
+
+def validation(epoch, model, data_loader, criterion, thr=0.5):
+    print(f'Start validation #{epoch:2d}')
+    model.eval()
+
+    dices = []
+    total_loss = 0
+    cnt = 0
+
+    with torch.no_grad():
+        for images, masks in tqdm(data_loader, total=len(data_loader)):
+            images, masks = images.cuda(), masks.cuda()
+
+            outputs = model(images)
+
+            # Resize the outputs if they don't match the mask resolution
+            output_h, output_w = outputs.size(-2), outputs.size(-1)
+            mask_h, mask_w = masks.size(-2), masks.size(-1)
+            if output_h != mask_h or output_w != mask_w:
+                outputs = F.interpolate(outputs, size=(mask_h, mask_w), mode="bilinear")
+
+            loss = criterion(outputs, masks)
+            total_loss += loss.item()
+            cnt += 1
+
+            outputs = torch.sigmoid(outputs)
+            outputs = (outputs > thr).detach().cpu()
+            masks = masks.detach().cpu()
+
+            dice = dice_coef(outputs, masks)
+            dices.append(dice)
+
+    dices = torch.cat(dices, 0)
+    dices_per_class = torch.mean(dices, 0)
+
+    # Print the dice score of every class
+    dice_str = [
+        f"{c:<12}: {d.item():.4f}"
+        for c, d in zip(Config.CLASSES, dices_per_class)
+    ]
+    print("\n".join(dice_str))
+
+    avg_dice = torch.mean(dices_per_class).item()
+    print(f"Avg loss: {total_loss / cnt:.4f} | Avg dice: {avg_dice:.4f}")
+    return avg_dice
+
+def train():
+    set_seed(Config.RANDOM_SEED)
+
+    # Prepare the model, loss, and optimizer
+    model = get_model(num_classes=len(Config.CLASSES)).cuda()
+    criterion = nn.BCEWithLogitsLoss()
+    optimizer = optim.Adam(
+        params=model.parameters(),
+        lr=Config.LEARNING_RATE,
+        weight_decay=1e-6
+    )
+
+    # Prepare the datasets: is_train selects the GroupKFold split
+    train_dataset = XRayDataset(
+        image_root=Config.TRAIN_IMAGE_ROOT,
+        label_root=Config.TRAIN_LABEL_ROOT,
+        is_train=True,
+        transforms=Transforms.get_train_transform()
+    )
+
+    valid_dataset = XRayDataset(
+        image_root=Config.TRAIN_IMAGE_ROOT,
+        label_root=Config.TRAIN_LABEL_ROOT,
+        is_train=False,
+        transforms=Transforms.get_valid_transform()
+    )
+
+    # DataLoaders
+    train_loader = DataLoader(
+        dataset=train_dataset,
+        batch_size=Config.BATCH_SIZE,
+        shuffle=True,
+        num_workers=8,
+        drop_last=True,
+    )
+
+    valid_loader = DataLoader(
+        dataset=valid_dataset,
+        batch_size=Config.BATCH_SIZE,
+        shuffle=False,
+        num_workers=4,
+        drop_last=False
+    )
+
+    # Training loop
+    best_dice = 0.
+    for epoch in range(Config.NUM_EPOCHS):
+        model.train()
+
+        for step, (images, masks) in enumerate(train_loader):
+            images, masks = images.cuda(), masks.cuda()
+
+            outputs = model(images)
+            loss = criterion(outputs, masks)
+
+            optimizer.zero_grad()
+            loss.backward()
+            optimizer.step()
+
+            if (step + 1) % 25 == 0:
+                print(
+                    f'{datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")} | '
+                    f'Epoch [{epoch+1}/{Config.NUM_EPOCHS}], '
+                    f'Step [{step+1}/{len(train_loader)}], '
+                    f'Loss: {round(loss.item(), 4)}'
+                )
+
+        if (epoch + 1) % Config.VAL_EVERY == 0:
+            dice = validation(epoch + 1, model, valid_loader, criterion)
+
+            if best_dice < dice:
+                print(f"Best performance at epoch: {epoch + 1}, {best_dice:.4f} -> {dice:.4f}")
+                print(f"Save model in {Config.SAVED_DIR}")
+                best_dice = dice
+                torch.save(model, os.path.join(Config.SAVED_DIR, "best_model.pt"))
+
+if __name__ == "__main__":
+    train()
\ No newline at end of file
diff --git a/smp/utils/metrics.py b/smp/utils/metrics.py
new file mode 100644
index 0000000..001ac05
--- /dev/null
+++ b/smp/utils/metrics.py
@@ -0,0 +1,11 @@
+import torch
+
+def dice_coef(y_true, y_pred):
+    # Flatten the spatial dims: (N, C, H, W) -> (N, C, H*W)
+    y_true_f = y_true.flatten(2)
+    y_pred_f = y_pred.flatten(2)
+    intersection = torch.sum(y_true_f * y_pred_f, -1)
+
+    # eps avoids 0/0 when a class is absent from both masks
+    eps = 0.0001
+    return (2. * intersection + eps) / (
+        torch.sum(y_true_f, -1) + torch.sum(y_pred_f, -1) + eps
+    )
\ No newline at end of file
diff --git a/smp/utils/rle.py b/smp/utils/rle.py
new file mode 100644
index 0000000..52fcb09
--- /dev/null
+++ b/smp/utils/rle.py
@@ -0,0 +1,20 @@
+import numpy as np
+
+def encode_mask_to_rle(mask):
+    # Encode a binary mask as "start length start length ...",
+    # with 1-indexed positions over the row-major flattened mask
+    pixels = mask.flatten()
+    pixels = np.concatenate([[0], pixels, [0]])
+    runs = np.where(pixels[1:] != pixels[:-1])[0] + 1
+    runs[1::2] -= runs[::2]
+    return ' '.join(str(x) for x in runs)
+
+def decode_rle_to_mask(rle, height, width):
+    # Inverse of encode_mask_to_rle: rebuild the (height, width) binary mask
+    s = rle.split()
+    starts, lengths = [np.asarray(x, dtype=int) for x in (s[0:][::2], s[1:][::2])]
+    starts -= 1
+    ends = starts + lengths
+    img = np.zeros(height * width, dtype=np.uint8)
+
+    for lo, hi in zip(starts, ends):
+        img[lo:hi] = 1
+
+    return img.reshape(height, width)
\ No newline at end of file
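
-- 
Notes: a few illustrative snippets for review; they are not part of the patch.

XRayDataset._split_dataset groups files by their parent directory, so every image under one folder (one patient) lands in the same fold and never appears in both train and validation. A minimal sketch of that behaviour, using a hypothetical folder layout (real IDs and filenames will differ):

    # Hypothetical layout: one folder per patient, two hands each
    import os
    from sklearn.model_selection import GroupKFold

    files = [f"ID{p:03d}/image_{side}.png" for p in range(1, 6) for side in ("L", "R")]
    groups = [os.path.dirname(f) for f in files]

    gkf = GroupKFold(n_splits=5)
    for i, (_, valid_idx) in enumerate(gkf.split(files, groups=groups)):
        # Each validation fold contains whole patients only
        print(f"fold {i}:", sorted({groups[j] for j in valid_idx}))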
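dice_coef in smp/utils/metrics.py takes stacked per-class masks of shape (N, C, H, W) and returns one score per sample and class; validation() then averages over samples to get per-class scores. A small shape sketch, assuming it is run from the smp/ directory so utils/ is importable:

    import torch
    from utils.metrics import dice_coef

    preds = torch.zeros(2, 29, 8, 8)
    gts = torch.zeros(2, 29, 8, 8)
    preds[:, 0, :4, :4] = 1   # 16 predicted pixels for class 0
    gts[:, 0, :4, 2:6] = 1    # 16 ground-truth pixels, 8 overlapping

    dice = dice_coef(gts, preds)
    print(dice.shape)   # torch.Size([2, 29])
    print(dice[0, 0])   # ~0.5, i.e. 2*8 / (16 + 16)
    # Classes empty in both masks score ~1.0 because eps prevents 0/0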
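The RLE helpers in smp/utils/rle.py encode each binary mask as space-separated (start, length) pairs, 1-indexed over the row-major flattened array. A round-trip sanity check, under the same import assumption:

    import numpy as np
    from utils.rle import encode_mask_to_rle, decode_rle_to_mask

    mask = np.zeros((4, 4), dtype=np.uint8)
    mask[1, 1:3] = 1   # flat positions 6-7 (1-indexed)
    mask[2, 0] = 1     # flat position 9

    rle = encode_mask_to_rle(mask)
    print(rle)   # "6 2 9 1"

    restored = decode_rle_to_mask(rle, height=4, width=4)
    assert (restored == mask).all()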