diff --git a/model.py b/model.py index 7df5b6889..606a74ecc 100644 --- a/model.py +++ b/model.py @@ -9,8 +9,45 @@ from keras.callbacks import ModelCheckpoint, LearningRateScheduler from keras import backend as keras +def jaccard_distance(y_true, y_pred, smooth=100): + intersection = K.sum(K.abs(y_true * y_pred), axis=-1) + sum_ = K.sum(K.abs(y_true) + K.abs(y_pred), axis=-1) + jac = (intersection + smooth) / (sum_ - intersection + smooth) + return (1 - jac) * smooth -def unet(pretrained_weights = None,input_size = (256,256,1)): +def dice_coef(y_true, y_pred): + smooth = 1 + y_true_f = K.flatten(y_true) + y_pred_f = K.flatten(y_pred) + intersection = K.sum(y_true_f * y_pred_f) + return (2. * intersection +smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) +smooth) + +def dice_coef_loss(y_true, y_pred): + print("dice loss") + return 1-dice_coef(y_true, y_pred) +def specificity(y_true,y_pred): + specificity=0 + y_true_f = K.flatten(y_true) + y_pred_f = K.flatten(y_pred) + TP=K.sum(y_true_f*y_pred_f) + TN=K.sum((1-y_true_f)*(1-y_pred_f)) + FP=K.sum((1-y_true_f)*y_pred_f) + FN=K.sum(y_true_f*(1-y_pred_f)) + specificity=(TN)/((TN+FP)) + return specificity +def sensitivity(y_true,y_pred): + sensitivity=0 + y_true_f = K.flatten(y_true) + y_pred_f = K.flatten(y_pred) + TP=K.sum(y_true_f*y_pred_f) + TN=K.sum((1-y_true_f)*(1-y_pred_f)) + FP=K.sum((1-y_true_f)*y_pred_f) + FN=K.sum(y_true_f*(1-y_pred_f)) + sensitivity=(TP)/((TP+FN)) + + return sensitivity + +def unet(pretrained_weights = None,input_size = (256,256,3)): inputs = Input(input_size) conv1 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(inputs) conv1 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv1) @@ -54,7 +91,7 @@ def unet(pretrained_weights = None,input_size = (256,256,1)): model = Model(input = inputs, output = conv10) - model.compile(optimizer = Adam(lr = 1e-4), loss = 'binary_crossentropy', metrics = ['accuracy']) + model.compile(optimizer = Adam(lr = 1e-4), loss = dice_coef_loss, metrics = ['accuracy']) #model.summary() @@ -62,5 +99,3 @@ def unet(pretrained_weights = None,input_size = (256,256,1)): model.load_weights(pretrained_weights) return model - -