From 49df2afb8afd208f13b4af343a981471166c0247 Mon Sep 17 00:00:00 2001 From: jhuni17 Date: Tue, 26 Nov 2024 16:40:38 +0900 Subject: [PATCH] feat: Add postprocessing script and add func in validate.py #38 --- smp_jh/utils/postprocessing.py | 88 ++++++++++++++++++++++++++++++++++ smp_jh/validate.py | 18 +++++-- 2 files changed, 103 insertions(+), 3 deletions(-) create mode 100644 smp_jh/utils/postprocessing.py diff --git a/smp_jh/utils/postprocessing.py b/smp_jh/utils/postprocessing.py new file mode 100644 index 0000000..bdddcef --- /dev/null +++ b/smp_jh/utils/postprocessing.py @@ -0,0 +1,88 @@ +import torch +import numpy as np +from scipy.ndimage import binary_dilation + +class AnatomicalPostProcessor: + def __init__(self, threshold=0.5): + self.threshold = threshold + self.finger_groups = [ + [0, 1, 2, 3], # 첫번째 손가락 + [4, 5, 6, 7], # 두번째 손가락 + [8, 9, 10, 11], # 세번째 손가락 + [12, 13, 14, 15], # 네번째 손가락 + [16, 17, 18] # 다섯번째 손가락 + ] + self.overlapping_pairs = [ + (19, 20), # Trapezium-Trapezoid + (25, 26) # Triquetrum-Pisiform + ] + + def __call__(self, prediction): + """ + prediction: (B, C, H, W) 형태의 모델 예측값 + C: 클래스 수 (29개: 19개 손가락 마디 + 8개 손목 뼈 + radius + ulna) + """ + return self.process(prediction) + + def process(self, prediction): + processed = prediction.clone() + batch_size = prediction.shape[0] + + for b in range(batch_size): + processed[b] = self._process_single_image(processed[b], prediction[b]) + + return processed + + def _process_single_image(self, processed, original): + # 1. 손가락 연속성 처리 + for finger in self.finger_groups: + for i in range(len(finger)-1): + curr_mask = processed[finger[i]] > self.threshold + next_mask = processed[finger[i+1]] > self.threshold + + if not self._check_connectivity(curr_mask, next_mask): + processed[finger[i]], processed[finger[i+1]] = \ + self._connect_segments(curr_mask, next_mask) + + # 2. 손목 뼈 겹침 처리 + for bone1, bone2 in self.overlapping_pairs: + mask1 = processed[bone1] > self.threshold + mask2 = processed[bone2] > self.threshold + + overlap = mask1 & mask2 + if overlap.any(): + processed[bone1][overlap] = \ + 1.0 if original[bone1][overlap].mean() > original[bone2][overlap].mean() else 0.0 + processed[bone2][overlap] = \ + 1.0 if original[bone2][overlap].mean() > original[bone1][overlap].mean() else 0.0 + + # 3. Radius-Ulna 관계 처리 + radius_mask = processed[-2] > self.threshold + ulna_mask = processed[-1] > self.threshold + processed[-2], processed[-1] = self._adjust_radius_ulna(radius_mask, ulna_mask) + + return processed + + @staticmethod + def _check_connectivity(mask1, mask2): + """두 마스크가 서로 연결되어 있는지 확인""" + dilated = binary_dilation(mask1.cpu().numpy(), iterations=2) + return np.any(dilated & mask2.cpu().numpy()) + + @staticmethod + def _connect_segments(mask1, mask2): + """두 분절을 연결""" + dilated1 = binary_dilation(mask1.cpu().numpy(), iterations=1) + dilated2 = binary_dilation(mask2.cpu().numpy(), iterations=1) + + connection = dilated1 & dilated2 + new_mask1 = mask1.cpu().numpy() | connection + new_mask2 = mask2.cpu().numpy() | connection + + return torch.from_numpy(new_mask1), torch.from_numpy(new_mask2) + + @staticmethod + def _adjust_radius_ulna(radius_mask, ulna_mask): + """Radius와 Ulna의 위치 관계 조정""" + # TODO: 구체적인 해부학적 규칙 구현 + return radius_mask, ulna_mask \ No newline at end of file diff --git a/smp_jh/validate.py b/smp_jh/validate.py index e4eeacc..3d16a27 100644 --- a/smp_jh/validate.py +++ b/smp_jh/validate.py @@ -17,6 +17,7 @@ from dataset.transforms import Transforms from utils.metrics import dice_coef from utils.rle import encode_mask_to_rle +from utils.postprocessing import AnatomicalPostProcessor def set_seed(seed): torch.manual_seed(seed) @@ -141,11 +142,14 @@ def __getitem__(self, item): return torch.from_numpy(image).float(), torch.from_numpy(label).float(), os.path.basename(image_path) -def validation(model, data_loader, device, threshold=0.5, save_gt=False): +def validation(model, data_loader, device, threshold=0.5, save_gt=False, use_postprocessing=False): """Validation 함수""" val_start = time.time() if model is not None: model.eval() + + # 후처리기 초기화 + postprocessor = AnatomicalPostProcessor(threshold=threshold) if use_postprocessing else None dices = [] pred_rles = [] @@ -162,6 +166,10 @@ def validation(model, data_loader, device, threshold=0.5, save_gt=False): # Forward pass outputs = model(images) + # Apply postprocessing if enabled + if use_postprocessing: + outputs = postprocessor(outputs) + # Resize for dice calculation output_h, output_w = outputs.size(-2), outputs.size(-1) mask_h, mask_w = masks.size(-2), masks.size(-1) @@ -264,7 +272,8 @@ def main(args): # Validation 실행 및 결과 저장 if args.save_gt: - pred_df = validation(model, valid_loader, device, args.threshold, save_gt=True) + pred_df = validation(model, valid_loader, device, args.threshold, save_gt=True, + use_postprocessing=args.postprocessing) gt_df = validation(None, valid_loader, device, args.threshold, save_gt=True) model_name = args.model_path.split('/')[-1] pred_df.to_csv(f"{model_name.split('.')[0]}_val.csv", index=False) @@ -272,7 +281,8 @@ def main(args): print(f"\nPrediction results saved to {model_name.split('.')[0]}_val.csv") print(f"Ground truth results saved to val_gt.csv") else: - pred_df = validation(model, valid_loader, device, args.threshold) + pred_df = validation(model, valid_loader, device, args.threshold, + use_postprocessing=args.postprocessing) model_name = args.model_path.split('/')[-1] pred_df.to_csv(f"{model_name.split('.')[0]}_val.csv", index=False) print(f"\nResults saved to {model_name.split('.')[0]}_val.csv") @@ -294,6 +304,8 @@ def main(args): help='Threshold for binary prediction') parser.add_argument('--save_gt', action='store_true', help='Save ground truth masks as separate CSV') + parser.add_argument('--postprocessing', action='store_true', + help='Enable anatomical postprocessing') args = parser.parse_args() main(args) \ No newline at end of file