"""Fine-tune four xView2 damage-classification models (res50, dpn92, res34, se154).

Reconstructed from a two-patch series: patch 1 added the script, patch 2 merged
the duplicated ``TrainData_34`` dataset into ``TrainData`` behind a ``res34``
flag.  This version also fixes a ``NameError`` in ``validate`` (``model`` ->
``net``), restores the 608x608 crop size the res34 model trained with (patch 2
had silently dropped it to 512), and corrects swapped FP/FN accumulators.

Usage: ``python tune_cls.py <seed>``
"""

import os

# Cap math-library threading before numpy / cv2 are imported.
os.environ["MKL_NUM_THREADS"] = "2"
os.environ["NUMEXPR_NUM_THREADS"] = "2"
os.environ["OMP_NUM_THREADS"] = "2"

from os import path, makedirs, listdir
import sys
import numpy as np
np.random.seed(1)
import random
random.seed(1)

import torch
from torch import nn
from torch.backends import cudnn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torch.optim.lr_scheduler as lr_scheduler

from apex import amp

from adamw import AdamW
from losses import dice_round, ComboLoss

import pandas as pd
from tqdm import tqdm
import timeit
import cv2

from zoo.models import Res34_Unet_Double
from zoo.models import SeResNext50_Unet_Double
from zoo.models import Dpn92_Unet_Double
from zoo.models import SeNet154_Unet_Double

from imgaug import augmenters as iaa

from utils import *

from skimage.morphology import square, dilation

from sklearn.model_selection import train_test_split

from sklearn.metrics import accuracy_score

import gc

cv2.setNumThreads(0)
cv2.ocl.setUseOpenCL(False)

train_dirs = ['/shared/ritwik/data/xview2/train']

models_folder = 'weights'

loc_folder = 'pred_loc_val'

input_shape = (512, 512)

# Index every *_pre_disaster.png image; the *_post_disaster.png pair and the
# mask paths are derived from the filename where needed.
all_files = []
for d in train_dirs:
    for f in sorted(listdir(path.join(d, 'images'))):
        if '_pre_disaster.png' in f:
            all_files.append(path.join(d, 'images', f))
train_len = len(all_files)


class TrainData(Dataset):
    """Heavily augmented training dataset over the global ``all_files`` index.

    Args:
        train_idxs: indices into ``all_files`` (may contain duplicates for
            oversampling).
        res34: accepted via ``**kwargs`` for backward compatibility.  When
            True, samples are prepared the way the Res34 model expects:
            608x608 crops and an inverted convention for the background
            channel of the target mask.  Defaults to False.
    """

    def __init__(self, train_idxs, **kwargs):
        # Default added so omitting the flag no longer raises KeyError.
        self.res34 = kwargs.pop('res34', False)
        super().__init__()
        self.train_idxs = train_idxs
        self.elastic = iaa.ElasticTransformation(alpha=(0.25, 1.2), sigma=0.2)

    def __len__(self):
        return len(self.train_idxs)

    def __getitem__(self, idx):
        _idx = self.train_idxs[idx]

        fn = all_files[_idx]

        # Pre/post disaster image pair plus the corresponding masks.
        img = cv2.imread(fn, cv2.IMREAD_COLOR)
        img2 = cv2.imread(fn.replace('_pre_', '_post_'), cv2.IMREAD_COLOR)

        msk0 = cv2.imread(fn.replace('/shared/ritwik/data/xview2/train/images/', 'masks/'), cv2.IMREAD_UNCHANGED)
        lbl_msk1 = cv2.imread(fn.replace('/shared/ritwik/data/xview2/train/images/', 'masks/').replace('_pre_disaster', '_post_disaster'), cv2.IMREAD_UNCHANGED)

        # One binary (0/255) mask per damage class 1..4.
        msk1 = np.zeros_like(lbl_msk1)
        msk2 = np.zeros_like(lbl_msk1)
        msk3 = np.zeros_like(lbl_msk1)
        msk4 = np.zeros_like(lbl_msk1)
        msk2[lbl_msk1 == 2] = 255
        msk3[lbl_msk1 == 3] = 255
        msk4[lbl_msk1 == 4] = 255
        msk1[lbl_msk1 == 1] = 255

        # Vertical flip (p=0.3), applied identically to images and masks.
        if random.random() > 0.7:
            img = img[::-1, ...]
            img2 = img2[::-1, ...]
            msk0 = msk0[::-1, ...]
            msk1 = msk1[::-1, ...]
            msk2 = msk2[::-1, ...]
            msk3 = msk3[::-1, ...]
            msk4 = msk4[::-1, ...]

        # Random 90-degree rotation (p=0.7).
        if random.random() > 0.3:
            rot = random.randrange(4)
            if rot > 0:
                img = np.rot90(img, k=rot)
                img2 = np.rot90(img2, k=rot)
                msk0 = np.rot90(msk0, k=rot)
                msk1 = np.rot90(msk1, k=rot)
                msk2 = np.rot90(msk2, k=rot)
                msk3 = np.rot90(msk3, k=rot)
                msk4 = np.rot90(msk4, k=rot)

        # Rare large translation (p=0.01).
        if random.random() > 0.99:
            shift_pnt = (random.randint(-320, 320), random.randint(-320, 320))
            img = shift_image(img, shift_pnt)
            img2 = shift_image(img2, shift_pnt)
            msk0 = shift_image(msk0, shift_pnt)
            msk1 = shift_image(msk1, shift_pnt)
            msk2 = shift_image(msk2, shift_pnt)
            msk3 = shift_image(msk3, shift_pnt)
            msk4 = shift_image(msk4, shift_pnt)

        # Small rotation + scale around a random pivot (p=0.5).
        if random.random() > 0.5:
            rot_pnt = (img.shape[0] // 2 + random.randint(-320, 320), img.shape[1] // 2 + random.randint(-320, 320))
            scale = 0.9 + random.random() * 0.2
            angle = random.randint(0, 20) - 10
            if (angle != 0) or (scale != 1):
                img = rotate_image(img, angle, scale, rot_pnt)
                img2 = rotate_image(img2, angle, scale, rot_pnt)
                msk0 = rotate_image(msk0, angle, scale, rot_pnt)
                msk1 = rotate_image(msk1, angle, scale, rot_pnt)
                msk2 = rotate_image(msk2, angle, scale, rot_pnt)
                msk3 = rotate_image(msk3, angle, scale, rot_pnt)
                msk4 = rotate_image(msk4, angle, scale, rot_pnt)

        # NOTE(review): patch 2's dedup dropped the 608x608 crop that
        # TrainData_34 used; restored here.  Shadows the module-level
        # ``input_shape`` intentionally (same as original code).
        input_shape = (608, 608) if self.res34 else (512, 512)
        crop_size = input_shape[0]
        if random.random() > 0.5:
            crop_size = random.randint(int(input_shape[0] / 1.1), int(input_shape[0] / 0.9))

        # Pick the crop with the most damage pixels among a few random tries,
        # weighting classes 2/3 highest (they are the rarest/hardest).
        bst_x0 = random.randint(0, img.shape[1] - crop_size)
        bst_y0 = random.randint(0, img.shape[0] - crop_size)
        bst_sc = -1
        try_cnt = random.randint(1, 10)
        for i in range(try_cnt):
            x0 = random.randint(0, img.shape[1] - crop_size)
            y0 = random.randint(0, img.shape[0] - crop_size)
            _sc = msk2[y0:y0+crop_size, x0:x0+crop_size].sum() * 5 + msk3[y0:y0+crop_size, x0:x0+crop_size].sum() * 5 + msk4[y0:y0+crop_size, x0:x0+crop_size].sum() * 2 + msk1[y0:y0+crop_size, x0:x0+crop_size].sum()
            if _sc > bst_sc:
                bst_sc = _sc
                bst_x0 = x0
                bst_y0 = y0
        x0 = bst_x0
        y0 = bst_y0
        img = img[y0:y0+crop_size, x0:x0+crop_size, :]
        img2 = img2[y0:y0+crop_size, x0:x0+crop_size, :]
        msk0 = msk0[y0:y0+crop_size, x0:x0+crop_size]
        msk1 = msk1[y0:y0+crop_size, x0:x0+crop_size]
        msk2 = msk2[y0:y0+crop_size, x0:x0+crop_size]
        msk3 = msk3[y0:y0+crop_size, x0:x0+crop_size]
        msk4 = msk4[y0:y0+crop_size, x0:x0+crop_size]

        if crop_size != input_shape[0]:
            img = cv2.resize(img, input_shape, interpolation=cv2.INTER_LINEAR)
            img2 = cv2.resize(img2, input_shape, interpolation=cv2.INTER_LINEAR)
            msk0 = cv2.resize(msk0, input_shape, interpolation=cv2.INTER_LINEAR)
            msk1 = cv2.resize(msk1, input_shape, interpolation=cv2.INTER_LINEAR)
            msk2 = cv2.resize(msk2, input_shape, interpolation=cv2.INTER_LINEAR)
            msk3 = cv2.resize(msk3, input_shape, interpolation=cv2.INTER_LINEAR)
            msk4 = cv2.resize(msk4, input_shape, interpolation=cv2.INTER_LINEAR)

        # Photometric augmentations, each extremely rare (p=0.01), image-only.
        if random.random() > 0.99:
            img = shift_channels(img, random.randint(-5, 5), random.randint(-5, 5), random.randint(-5, 5))
        elif random.random() > 0.99:
            img2 = shift_channels(img2, random.randint(-5, 5), random.randint(-5, 5), random.randint(-5, 5))

        if random.random() > 0.99:
            img = change_hsv(img, random.randint(-5, 5), random.randint(-5, 5), random.randint(-5, 5))
        elif random.random() > 0.99:
            img2 = change_hsv(img2, random.randint(-5, 5), random.randint(-5, 5), random.randint(-5, 5))

        if random.random() > 0.99:
            if random.random() > 0.99:
                img = clahe(img)
            elif random.random() > 0.99:
                img = gauss_noise(img)
            elif random.random() > 0.99:
                img = cv2.blur(img, (3, 3))
        elif random.random() > 0.99:
            if random.random() > 0.99:
                img = saturation(img, 0.9 + random.random() * 0.2)
            elif random.random() > 0.99:
                img = brightness(img, 0.9 + random.random() * 0.2)
            elif random.random() > 0.99:
                img = contrast(img, 0.9 + random.random() * 0.2)

        if random.random() > 0.99:
            if random.random() > 0.99:
                img2 = clahe(img2)
            elif random.random() > 0.99:
                img2 = gauss_noise(img2)
            elif random.random() > 0.99:
                img2 = cv2.blur(img2, (3, 3))
        elif random.random() > 0.99:
            if random.random() > 0.99:
                img2 = saturation(img2, 0.9 + random.random() * 0.2)
            elif random.random() > 0.99:
                img2 = brightness(img2, 0.9 + random.random() * 0.2)
            elif random.random() > 0.99:
                img2 = contrast(img2, 0.9 + random.random() * 0.2)

        if random.random() > 0.99:
            el_det = self.elastic.to_deterministic()
            img = el_det.augment_image(img)

        if random.random() > 0.99:
            el_det = self.elastic.to_deterministic()
            img2 = el_det.augment_image(img2)

        msk0 = msk0[..., np.newaxis]
        msk1 = msk1[..., np.newaxis]
        msk2 = msk2[..., np.newaxis]
        msk3 = msk3[..., np.newaxis]
        msk4 = msk4[..., np.newaxis]

        msk = np.concatenate([msk0, msk1, msk2, msk3, msk4], axis=2)
        msk = (msk > 127)

        # Background-channel convention differs per model family (res34 uses
        # the inverted convention that TrainData_34 used before the merge).
        msk[..., 0] = (not self.res34)
        msk[..., 1] = dilation(msk[..., 1], square(5))
        msk[..., 2] = dilation(msk[..., 2], square(5))
        msk[..., 3] = dilation(msk[..., 3], square(5))
        msk[..., 4] = dilation(msk[..., 4], square(5))
        # Enforce class priority 2 > 3 > 4 > 1 where dilated masks overlap.
        msk[..., 1][msk[..., 2:].max(axis=2)] = False
        msk[..., 3][msk[..., 2]] = False
        msk[..., 4][msk[..., 2]] = False
        msk[..., 4][msk[..., 3]] = False
        msk[..., 0][msk[..., 1:].max(axis=2)] = self.res34
        msk = msk * 1

        lbl_msk = msk.argmax(axis=2)

        # 6-channel input: pre-disaster RGB stacked with post-disaster RGB.
        img = np.concatenate([img, img2], axis=2)
        img = preprocess_inputs(img)

        img = torch.from_numpy(img.transpose((2, 0, 1))).float()
        msk = torch.from_numpy(msk.transpose((2, 0, 1))).long()

        sample = {'img': img, 'msk': msk, 'lbl_msk': lbl_msk, 'fn': fn}
        return sample


class ValData(Dataset):
    """Validation dataset: no augmentation, plus the precomputed localization
    mask from ``loc_folder`` used as the building-footprint prediction."""

    def __init__(self, image_idxs):
        super().__init__()
        self.image_idxs = image_idxs

    def __len__(self):
        return len(self.image_idxs)

    def __getitem__(self, idx):
        _idx = self.image_idxs[idx]

        fn = all_files[_idx]

        img = cv2.imread(fn, cv2.IMREAD_COLOR)
        img2 = cv2.imread(fn.replace('_pre_', '_post_'), cv2.IMREAD_COLOR)

        msk0 = cv2.imread(fn.replace('/shared/ritwik/data/xview2/train/images/', 'masks/'), cv2.IMREAD_UNCHANGED)
        lbl_msk1 = cv2.imread(fn.replace('/shared/ritwik/data/xview2/train/images/', 'masks/').replace('_pre_disaster', '_post_disaster'), cv2.IMREAD_UNCHANGED)
        # Binarize the saved localization prediction at 0.3 of full scale.
        msk_loc = cv2.imread(path.join(loc_folder, '{0}.png'.format(fn.split('/')[-1].replace('.png', '_part1.png'))), cv2.IMREAD_UNCHANGED) > (0.3*255)

        msk1 = np.zeros_like(lbl_msk1)
        msk2 = np.zeros_like(lbl_msk1)
        msk3 = np.zeros_like(lbl_msk1)
        msk4 = np.zeros_like(lbl_msk1)
        msk1[lbl_msk1 == 1] = 255
        msk2[lbl_msk1 == 2] = 255
        msk3[lbl_msk1 == 3] = 255
        msk4[lbl_msk1 == 4] = 255

        msk0 = msk0[..., np.newaxis]
        msk1 = msk1[..., np.newaxis]
        msk2 = msk2[..., np.newaxis]
        msk3 = msk3[..., np.newaxis]
        msk4 = msk4[..., np.newaxis]

        msk = np.concatenate([msk0, msk1, msk2, msk3, msk4], axis=2)
        msk = (msk > 127)

        msk = msk * 1

        lbl_msk = msk[..., 1:].argmax(axis=2)

        img = np.concatenate([img, img2], axis=2)
        img = preprocess_inputs(img)

        img = torch.from_numpy(img.transpose((2, 0, 1))).float()
        msk = torch.from_numpy(msk.transpose((2, 0, 1))).long()

        sample = {'img': img, 'msk': msk, 'lbl_msk': lbl_msk, 'fn': fn, 'msk_loc': msk_loc}
        return sample


def validate(net, data_loader, res34=False):
    """Run validation and return the xView2 combined score.

    Args:
        net: the model to evaluate (already on GPU, in eval mode).
        data_loader: DataLoader over ``ValData`` samples.
        res34: the res34 head is sigmoid-activated with 4 damage channels;
            the others are softmax over 5 channels and also track a
            localization Dice via tp/fp/fn index 4.

    Returns:
        float: 0.3 * localization score + 0.7 * harmonic-mean damage F1.
    """
    dices0 = []

    dim = 4 if res34 else 5
    tp = np.zeros((dim,))
    fp = np.zeros((dim,))
    fn = np.zeros((dim,))

    _thr = 0.3

    with torch.no_grad():
        for i, sample in enumerate(tqdm(data_loader)):
            msks = sample["msk"].numpy()
            lbl_msk = sample["lbl_msk"].numpy()
            imgs = sample["img"].cuda(non_blocking=True)
            msk_loc = sample["msk_loc"].numpy() * 1
            # FIX: original read the undefined global ``model`` (NameError);
            # the network is the ``net`` parameter.
            out = net(imgs)

            msk_pred = msk_loc
            if res34:
                msk_damage_pred = torch.sigmoid(out).cpu().numpy()[:, 1:, ...]

                for j in range(msks.shape[0]):
                    dices0.append(dice(msks[j, 0], msk_pred[j] > _thr))

                    targ = lbl_msk[j][msks[j, 0] > 0]
                    pred = msk_damage_pred[j].argmax(axis=0)
                    pred = pred * (msk_pred[j] > _thr)
                    pred = pred[msks[j, 0] > 0]
                    for c in range(4):
                        tp[c] += np.logical_and(pred == c, targ == c).sum()
                        fn[c] += np.logical_and(pred != c, targ == c).sum()
                        fp[c] += np.logical_and(pred == c, targ != c).sum()

            else:
                msk_damage_pred = torch.softmax(out, dim=1).cpu().numpy()[:, 1:, ...]

                for j in range(msks.shape[0]):
                    # FIX: fn/fp were swapped in the original (FN = target
                    # positive & prediction negative).  The Dice below is
                    # symmetric in fp+fn, so the reported score is unchanged.
                    tp[4] += np.logical_and(msks[j, 0] > 0, msk_pred[j] > 0).sum()
                    fn[4] += np.logical_and(msks[j, 0] > 0, msk_pred[j] < 1).sum()
                    fp[4] += np.logical_and(msks[j, 0] < 1, msk_pred[j] > 0).sum()

                    targ = lbl_msk[j][msks[j, 0] > 0]
                    pred = msk_damage_pred[j].argmax(axis=0)
                    pred = pred * (msk_pred[j] > _thr)
                    pred = pred[msks[j, 0] > 0]
                    for c in range(4):
                        tp[c] += np.logical_and(pred == c, targ == c).sum()
                        fn[c] += np.logical_and(pred != c, targ == c).sum()
                        fp[c] += np.logical_and(pred == c, targ != c).sum()

    if res34:
        d0 = np.mean(dices0)
    else:
        d0 = 2 * tp[4] / (2 * tp[4] + fp[4] + fn[4])

    f1_sc = np.zeros((4,))
    for c in range(4):
        f1_sc[c] = 2 * tp[c] / (2 * tp[c] + fp[c] + fn[c])

    # Harmonic mean of the four per-class F1 scores.
    f1 = 4 / np.sum(1.0 / (f1_sc + 1e-6))

    sc = 0.3 * d0 + 0.7 * f1
    print("Val Score: {}, Dice: {}, F1: {}, F1_0: {}, F1_1: {}, F1_2: {}, F1_3: {}".format(sc, d0, f1, f1_sc[0], f1_sc[1], f1_sc[2], f1_sc[3]))
    return sc


def evaluate_val(data_val, best_score, model, snapshot_name, current_epoch, res34):
    """Validate ``model`` and checkpoint it when the score improves.

    Returns the (possibly updated) best score.
    """
    model = model.eval()
    d = validate(model, data_loader=data_val, res34=res34)

    if d > best_score:
        torch.save({
            'epoch': current_epoch + 1,
            'state_dict': model.state_dict(),
            'best_score': d,
        }, path.join(models_folder, snapshot_name + '_best'))
        best_score = d

    print("score: {}\tscore_best: {}".format(d, best_score))
    return best_score


def train_epoch(current_epoch, seg_loss, ce_loss, model, optimizer, scheduler, train_data_loader, model_name):
    """Train ``model`` for one epoch.

    res34 uses per-channel seg losses only; the other models add a weighted
    cross-entropy term (``loss5``) over the 5-channel output.
    """
    losses = AverageMeter()
    losses1 = AverageMeter()

    dices = AverageMeter()

    iterator = tqdm(train_data_loader)
    model.train()
    for i, sample in enumerate(iterator):
        imgs = sample["img"].cuda(non_blocking=True)
        msks = sample["msk"].cuda(non_blocking=True)
        if model_name != 'res34':
            lbl_msk = sample["lbl_msk"].cuda(non_blocking=True)

        out = model(imgs)

        loss0 = seg_loss(out[:, 0, ...], msks[:, 0, ...])
        loss1 = seg_loss(out[:, 1, ...], msks[:, 1, ...])
        loss2 = seg_loss(out[:, 2, ...], msks[:, 2, ...])
        loss3 = seg_loss(out[:, 3, ...], msks[:, 3, ...])
        loss4 = seg_loss(out[:, 4, ...], msks[:, 4, ...])

        if model_name != 'res34':
            loss5 = ce_loss(out, lbl_msk)

        # Per-model loss weights (tuned per architecture).
        if model_name == 'res34':
            loss = 0.05 * loss0 + 0.2 * loss1 + 0.8 * loss2 + 0.7 * loss3 + 0.4 * loss4
        elif model_name == 'res50':
            loss = 0.1 * loss0 + 0.1 * loss1 + 0.3 * loss2 + 0.3 * loss3 + 0.2 * loss4 + loss5 * 11
        elif model_name == 'dpn92':
            loss = 0.1 * loss0 + 0.1 * loss1 + 0.5 * loss2 + 0.3 * loss3 + 0.2 * loss4 + loss5 * 11
        else:
            loss = 0.1 * loss0 + 0.1 * loss1 + 0.6 * loss2 + 0.3 * loss3 + 0.2 * loss4 + loss5 * 8

        with torch.no_grad():
            # Background channel is inverted for non-res34 models (see
            # TrainData), hence the 1 - sigmoid here.
            if model_name != 'res34':
                _probs = 1 - torch.sigmoid(out[:, 0, ...])
                dice_sc = 1 - dice_round(_probs, 1 - msks[:, 0, ...])
            else:
                _probs = torch.sigmoid(out[:, 0, ...])
                dice_sc = 1 - dice_round(_probs, msks[:, 0, ...])

        losses.update(loss.item(), imgs.size(0))
        if model_name != 'res34':
            losses1.update(loss5.item(), imgs.size(0))
        else:
            losses1.update(loss2.item(), imgs.size(0))

        dices.update(dice_sc, imgs.size(0))

        if model_name != 'res34':
            iterator.set_description(
                "epoch: {}; lr {:.7f}; Loss {loss.val:.4f} ({loss.avg:.4f}); cce_loss {loss1.val:.4f} ({loss1.avg:.4f}); Dice {dice.val:.4f} ({dice.avg:.4f})".format(
                    current_epoch, scheduler.get_lr()[-1], loss=losses, loss1=losses1, dice=dices))
        else:
            iterator.set_description(
                "epoch: {}; lr {:.7f}; Loss {loss.val:.4f} ({loss.avg:.4f}); loss2 {loss1.val:.4f} ({loss1.avg:.4f}); Dice {dice.val:.4f} ({dice.avg:.4f})".format(
                    current_epoch, scheduler.get_lr()[-1], loss=losses, loss1=losses1, dice=dices))

        optimizer.zero_grad()
        # apex mixed-precision backward.
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
        torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), 0.999)
        optimizer.step()

    scheduler.step(current_epoch)

    if model_name != 'res34':
        print("epoch: {}; lr {:.7f}; Loss {loss.avg:.4f}; CCE_loss {loss1.avg:.4f}; Dice {dice.avg:.4f}".format(
            current_epoch, scheduler.get_lr()[-1], loss=losses, loss1=losses1, dice=dices))
    else:
        print("epoch: {}; lr {:.7f}; Loss {loss.avg:.4f}; loss2 {loss1.avg:.4f}; Dice {dice.avg:.4f}".format(
            current_epoch, scheduler.get_lr()[-1], loss=losses, loss1=losses1, dice=dices))


def finetune(train_idxs, val_idxs, model_name, seed, snapshot_name, snap_to_load, lr, focal, epoch_num, batch_size, val_batch_size):
    """Load a pretrained snapshot and fine-tune one model.

    Args:
        model_name: one of 'res34', 'res50', 'dpn92', 'se154'.
        focal: focal-loss weight for the ComboLoss, or None for dice-only.
    """
    assert model_name in ['res34', 'res50', 'dpn92', 'se154']

    steps_per_epoch = len(train_idxs) // batch_size
    validation_steps = len(val_idxs) // val_batch_size

    print('finetuning', model_name, 'steps_per_epoch', steps_per_epoch, 'validation_steps', validation_steps)

    data_train = TrainData(train_idxs, res34=(model_name=='res34'))
    val_train = ValData(val_idxs)

    train_data_loader = DataLoader(data_train, batch_size=batch_size, num_workers=6, shuffle=True, pin_memory=False, drop_last=True)
    val_data_loader = DataLoader(val_train, batch_size=val_batch_size, num_workers=6, shuffle=False, pin_memory=False)

    model = SeResNext50_Unet_Double().cuda()
    if model_name == 'dpn92':
        model = Dpn92_Unet_Double().cuda()
    elif model_name == 'res34':
        model = Res34_Unet_Double().cuda()
    elif model_name == 'se154':
        model = SeNet154_Unet_Double().cuda()

    params = model.parameters()

    optimizer = AdamW(params, lr=lr, weight_decay=1e-6)

    model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

    scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[1, 2, 3, 4, 5, 7, 9, 11, 17, 23, 29, 33, 47, 50, 60, 70, 90, 110, 130, 150, 170, 180, 190], gamma=0.5)

    model = nn.DataParallel(model).cuda()

    print("=> loading checkpoint '{}'".format(snap_to_load))
    checkpoint = torch.load(path.join(models_folder, snap_to_load), map_location='cpu')
    loaded_dict = checkpoint['state_dict']
    sd = model.state_dict()
    # Copy only the tensors whose names and shapes match the current model.
    for k in model.state_dict():
        if k in loaded_dict and sd[k].size() == loaded_dict[k].size():
            sd[k] = loaded_dict[k]
    loaded_dict = sd
    model.load_state_dict(loaded_dict)
    print("loaded checkpoint '{}' (epoch {}, best_score {})"
          .format(snap_to_load, checkpoint['epoch'], checkpoint['best_score']))
    del loaded_dict
    del sd
    del checkpoint
    gc.collect()
    torch.cuda.empty_cache()

    if focal is not None:
        seg_loss = ComboLoss({'dice': 0.5, 'focal': focal}, per_image=False).cuda()
    else:
        seg_loss = ComboLoss({'dice': 0.5}, per_image=False).cuda()
    ce_loss = nn.CrossEntropyLoss().cuda()

    best_score = 0
    torch.cuda.empty_cache()
    for epoch in range(epoch_num):
        train_epoch(epoch, seg_loss, ce_loss, model, optimizer, scheduler, train_data_loader, model_name)
        torch.cuda.empty_cache()
        best_score = evaluate_val(val_data_loader, best_score, model, snapshot_name, epoch, res34=(model_name=='res34'))

    # ``t0`` is set at module level in the __main__ block.
    elapsed = timeit.default_timer() - t0
    print('Time: {:.3f} min'.format(elapsed / 60))


if __name__ == '__main__':
    t0 = timeit.default_timer()

    makedirs(models_folder, exist_ok=True)

    seed = int(sys.argv[1])
    # vis_dev = sys.argv[2]

    # os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
    # os.environ["CUDA_VISIBLE_DEVICES"] = vis_dev

    cudnn.benchmark = True

    # Which damage classes (1..4) appear in each image's post-disaster mask.
    file_classes = []
    for fn in tqdm(all_files):
        fl = np.zeros((4,), dtype=bool)
        msk1 = cv2.imread(fn.replace('/shared/ritwik/data/xview2/train/images/', 'masks/').replace('_pre_disaster', '_post_disaster'), cv2.IMREAD_UNCHANGED)
        for c in range(1, 5):
            fl[c-1] = c in msk1
        file_classes.append(fl)
    file_classes = np.asarray(file_classes)

    _, val_idxs = train_test_split(np.arange(train_len), test_size=0.1, random_state=seed)

    np.random.seed(seed + 131313)
    random.seed(seed + 131313)

    # Oversample (2x) every image that contains any damaged building.
    train_idxs = []
    for i in np.arange(len(all_files)):
        train_idxs.append(i)
        if file_classes[i, 1:].max():
            train_idxs.append(i)
    train_idxs = np.asarray(train_idxs)

    finetune(train_idxs=train_idxs, val_idxs=val_idxs, model_name='res50', seed=seed, snapshot_name='res50_cls_cce_{}_tuned'.format(seed), snap_to_load='res50_cls_cce_{}_0_best'.format(seed), lr=0.00001, focal=2.0, epoch_num=2, batch_size=16, val_batch_size=4)

    finetune(train_idxs=train_idxs, val_idxs=val_idxs, model_name='dpn92', seed=seed, snapshot_name='dpn92_cls_cce_{}_tuned'.format(seed), snap_to_load='dpn92_cls_cce_{}_1_best'.format(seed), lr=0.000008, focal=5.0, epoch_num=1, batch_size=12, val_batch_size=4)

    finetune(train_idxs=train_idxs, val_idxs=val_idxs, model_name='res34', seed=seed, snapshot_name='res34_cls2_{}_tuned'.format(seed), snap_to_load='res34_cls2_{}_0_best'.format(seed), lr=0.000008, focal=12.0, epoch_num=3, batch_size=16, val_batch_size=8)

    finetune(train_idxs=train_idxs, val_idxs=val_idxs, model_name='se154', seed=seed, snapshot_name='se154_cls_cce_{}_tuned'.format(seed), snap_to_load='se154_cls_cce_{}_1_best'.format(seed), lr=0.000008, focal=None, epoch_num=2, batch_size=8, val_batch_size=2)