Skip to content

Commit

Permalink
feat: Upsample -> ConvTranspose2dUNet3+ #25
Browse files Browse the repository at this point in the history
  • Loading branch information
MCG committed Nov 22, 2024
1 parent 1b34839 commit 51412da
Show file tree
Hide file tree
Showing 4 changed files with 167 additions and 52 deletions.
181 changes: 137 additions & 44 deletions UNet3+/Code/Model/FixedModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from Util.SetSeed import set_seed
from .layer import unetConv2, BottleNeck
import torchvision.models as models

set_seed()

class UNet_3Plus_DeepSup(nn.Module):
Expand All @@ -24,23 +25,22 @@ def __init__(self, in_channels=3, n_classes=1, feature_scale=4, is_deconv=True,
## -------------Encoder--------------
self.conv1 = nn.Sequential(
unetConv2(self.in_channels, filters[0], self.is_batchnorm),
#nn.MaxPool2d(kernel_size=2),
BottleNeck(64, 128) # BottleNeck 추가
nn.Dropout(p=0.05),
BottleNeck(filters[0], filters[0] // 2)
)

self.conv2 = nn.Sequential(
nn.MaxPool2d(kernel_size=2),
unetConv2(filters[0], filters[1], self.is_batchnorm),
BottleNeck(128, 256) # BottleNeck 추가
nn.Dropout(p=0.1),
BottleNeck(filters[1], filters[1] // 2)
)

self.conv3 = nn.Sequential(
nn.MaxPool2d(kernel_size=2),
unetConv2(filters[1], filters[2], self.is_batchnorm),
BottleNeck(256, 512) # BottleNeck 추가
nn.Dropout(p=0.15),
BottleNeck(filters[2], filters[2] // 2)
)
self.conv4 = self.resnet.layer2

self.conv5 = self.resnet.layer3
## -------------Decoder--------------
self.CatChannels = filters[0]
Expand Down Expand Up @@ -71,9 +71,14 @@ def __init__(self, in_channels=3, n_classes=1, feature_scale=4, is_deconv=True,
self.h4_Cat_hd4_bn = nn.BatchNorm2d(self.CatChannels)
self.h4_Cat_hd4_relu = nn.ReLU(inplace=True)

# hd5->20*20, hd4->40*40, Upsample 2 times
self.hd5_UT_hd4 = nn.Upsample(scale_factor=2, mode='bilinear') # 14*14
self.hd5_UT_hd4_conv = nn.Conv2d(filters[4], self.CatChannels, 3, padding=1)
# hd5->20*20, hd4->40*40, Upsample 2 times (Using ConvTranspose2d)
self.hd5_UT_hd4 = nn.ConvTranspose2d(
in_channels=filters[4], # 입력 채널
out_channels=self.CatChannels, # 출력 채널
kernel_size=2, # 업샘플링 크기
stride=2 # 2배 업샘플링
)
self.hd5_UT_hd4_conv = nn.Conv2d(self.CatChannels, self.CatChannels, 3, padding=1)
self.hd5_UT_hd4_bn = nn.BatchNorm2d(self.CatChannels)
self.hd5_UT_hd4_relu = nn.ReLU(inplace=True)

Expand All @@ -82,6 +87,7 @@ def __init__(self, in_channels=3, n_classes=1, feature_scale=4, is_deconv=True,
self.bn4d_1 = nn.BatchNorm2d(self.UpChannels)
self.relu4d_1 = nn.ReLU(inplace=True)


'''stage 3d'''
# h1->320*320, hd3->80*80, Pooling 4 times
self.h1_PT_hd3 = nn.MaxPool2d(4, 4, ceil_mode=True)
Expand All @@ -100,15 +106,25 @@ def __init__(self, in_channels=3, n_classes=1, feature_scale=4, is_deconv=True,
self.h3_Cat_hd3_bn = nn.BatchNorm2d(self.CatChannels)
self.h3_Cat_hd3_relu = nn.ReLU(inplace=True)

# hd4->40*40, hd4->80*80, Upsample 2 times
self.hd4_UT_hd3 = nn.Upsample(scale_factor=2, mode='bilinear') # 14*14
self.hd4_UT_hd3_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1)
# hd4->40*40, hd4->80*80, Upsample 2 times (Using ConvTranspose2d)
self.hd4_UT_hd3 = nn.ConvTranspose2d(
in_channels=self.UpChannels, # 입력 채널
out_channels=self.CatChannels, # 출력 채널
kernel_size=2, # 업샘플링 크기
stride=2 # 2배 업샘플링
)
self.hd4_UT_hd3_conv = nn.Conv2d(self.CatChannels, self.CatChannels, 3, padding=1)
self.hd4_UT_hd3_bn = nn.BatchNorm2d(self.CatChannels)
self.hd4_UT_hd3_relu = nn.ReLU(inplace=True)

# hd5->20*20, hd4->80*80, Upsample 4 times
self.hd5_UT_hd3 = nn.Upsample(scale_factor=4, mode='bilinear') # 14*14
self.hd5_UT_hd3_conv = nn.Conv2d(filters[4], self.CatChannels, 3, padding=1)
# hd5->20*20, hd4->80*80, Upsample 4 times (Using ConvTranspose2d)
self.hd5_UT_hd3 = nn.ConvTranspose2d(
in_channels=filters[4], # 입력 채널
out_channels=self.CatChannels, # 출력 채널
kernel_size=4, # 업샘플링 크기
stride=4 # 4배 업샘플링
)
self.hd5_UT_hd3_conv = nn.Conv2d(self.CatChannels, self.CatChannels, 3, padding=1)
self.hd5_UT_hd3_bn = nn.BatchNorm2d(self.CatChannels)
self.hd5_UT_hd3_relu = nn.ReLU(inplace=True)

Expand All @@ -117,6 +133,7 @@ def __init__(self, in_channels=3, n_classes=1, feature_scale=4, is_deconv=True,
self.bn3d_1 = nn.BatchNorm2d(self.UpChannels)
self.relu3d_1 = nn.ReLU(inplace=True)


'''stage 2d '''
# h1->320*320, hd2->160*160, Pooling 2 times
self.h1_PT_hd2 = nn.MaxPool2d(2, 2, ceil_mode=True)
Expand All @@ -129,21 +146,36 @@ def __init__(self, in_channels=3, n_classes=1, feature_scale=4, is_deconv=True,
self.h2_Cat_hd2_bn = nn.BatchNorm2d(self.CatChannels)
self.h2_Cat_hd2_relu = nn.ReLU(inplace=True)

# hd3->80*80, hd2->160*160, Upsample 2 times
self.hd3_UT_hd2 = nn.Upsample(scale_factor=2, mode='bilinear') # 14*14
self.hd3_UT_hd2_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1)
# hd3->80*80, hd2->160*160, Upsample 2 times (Using ConvTranspose2d)
self.hd3_UT_hd2 = nn.ConvTranspose2d(
in_channels=self.UpChannels, # 입력 채널
out_channels=self.CatChannels, # 출력 채널
kernel_size=2, # 업샘플링 크기
stride=2 # 2배 업샘플링
)
self.hd3_UT_hd2_conv = nn.Conv2d(self.CatChannels, self.CatChannels, 3, padding=1)
self.hd3_UT_hd2_bn = nn.BatchNorm2d(self.CatChannels)
self.hd3_UT_hd2_relu = nn.ReLU(inplace=True)

# hd4->40*40, hd2->160*160, Upsample 4 times
self.hd4_UT_hd2 = nn.Upsample(scale_factor=4, mode='bilinear') # 14*14
self.hd4_UT_hd2_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1)
# hd4->40*40, hd2->160*160, Upsample 4 times (Using ConvTranspose2d)
self.hd4_UT_hd2 = nn.ConvTranspose2d(
in_channels=self.UpChannels, # 입력 채널
out_channels=self.CatChannels, # 출력 채널
kernel_size=4, # 업샘플링 크기
stride=4 # 4배 업샘플링
)
self.hd4_UT_hd2_conv = nn.Conv2d(self.CatChannels, self.CatChannels, 3, padding=1)
self.hd4_UT_hd2_bn = nn.BatchNorm2d(self.CatChannels)
self.hd4_UT_hd2_relu = nn.ReLU(inplace=True)

# hd5->20*20, hd2->160*160, Upsample 8 times
self.hd5_UT_hd2 = nn.Upsample(scale_factor=8, mode='bilinear') # 14*14
self.hd5_UT_hd2_conv = nn.Conv2d(filters[4], self.CatChannels, 3, padding=1)
# hd5->20*20, hd2->160*160, Upsample 8 times (Using ConvTranspose2d)
self.hd5_UT_hd2 = nn.ConvTranspose2d(
in_channels=filters[4], # 입력 채널
out_channels=self.CatChannels, # 출력 채널
kernel_size=8, # 업샘플링 크기
stride=8 # 8배 업샘플링
)
self.hd5_UT_hd2_conv = nn.Conv2d(self.CatChannels, self.CatChannels, 3, padding=1)
self.hd5_UT_hd2_bn = nn.BatchNorm2d(self.CatChannels)
self.hd5_UT_hd2_relu = nn.ReLU(inplace=True)

Expand All @@ -152,33 +184,54 @@ def __init__(self, in_channels=3, n_classes=1, feature_scale=4, is_deconv=True,
self.bn2d_1 = nn.BatchNorm2d(self.UpChannels)
self.relu2d_1 = nn.ReLU(inplace=True)


'''stage 1d'''
# h1->320*320, hd1->320*320, Concatenation
self.h1_Cat_hd1_conv = nn.Conv2d(filters[0], self.CatChannels, 3, padding=1)
self.h1_Cat_hd1_bn = nn.BatchNorm2d(self.CatChannels)
self.h1_Cat_hd1_relu = nn.ReLU(inplace=True)

# hd2->160*160, hd1->320*320, Upsample 2 times
self.hd2_UT_hd1 = nn.Upsample(scale_factor=2, mode='bilinear') # 14*14
self.hd2_UT_hd1_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1)
# hd2->160*160, hd1->320*320, Upsample 2 times (Using ConvTranspose2d)
self.hd2_UT_hd1 = nn.ConvTranspose2d(
in_channels=self.UpChannels, # 입력 채널
out_channels=self.CatChannels, # 출력 채널
kernel_size=2, # 업샘플링 크기
stride=2 # 2배 업샘플링
)
self.hd2_UT_hd1_conv = nn.Conv2d(self.CatChannels, self.CatChannels, 3, padding=1)
self.hd2_UT_hd1_bn = nn.BatchNorm2d(self.CatChannels)
self.hd2_UT_hd1_relu = nn.ReLU(inplace=True)

# hd3->80*80, hd1->320*320, Upsample 4 times
self.hd3_UT_hd1 = nn.Upsample(scale_factor=4, mode='bilinear') # 14*14
self.hd3_UT_hd1_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1)
# hd3->80*80, hd1->320*320, Upsample 4 times (Using ConvTranspose2d)
self.hd3_UT_hd1 = nn.ConvTranspose2d(
in_channels=self.UpChannels, # 입력 채널
out_channels=self.CatChannels, # 출력 채널
kernel_size=4, # 업샘플링 크기
stride=4 # 4배 업샘플링
)
self.hd3_UT_hd1_conv = nn.Conv2d(self.CatChannels, self.CatChannels, 3, padding=1)
self.hd3_UT_hd1_bn = nn.BatchNorm2d(self.CatChannels)
self.hd3_UT_hd1_relu = nn.ReLU(inplace=True)

# hd4->40*40, hd1->320*320, Upsample 8 times
self.hd4_UT_hd1 = nn.Upsample(scale_factor=8, mode='bilinear') # 14*14
self.hd4_UT_hd1_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1)
# hd4->40*40, hd1->320*320, Upsample 8 times (Using ConvTranspose2d)
self.hd4_UT_hd1 = nn.ConvTranspose2d(
in_channels=self.UpChannels, # 입력 채널
out_channels=self.CatChannels, # 출력 채널
kernel_size=8, # 업샘플링 크기
stride=8 # 8배 업샘플링
)
self.hd4_UT_hd1_conv = nn.Conv2d(self.CatChannels, self.CatChannels, 3, padding=1)
self.hd4_UT_hd1_bn = nn.BatchNorm2d(self.CatChannels)
self.hd4_UT_hd1_relu = nn.ReLU(inplace=True)

# hd5->20*20, hd1->320*320, Upsample 16 times
self.hd5_UT_hd1 = nn.Upsample(scale_factor=16, mode='bilinear') # 14*14
self.hd5_UT_hd1_conv = nn.Conv2d(filters[4], self.CatChannels, 3, padding=1)
# hd5->20*20, hd1->320*320, Upsample 16 times (Using ConvTranspose2d)
self.hd5_UT_hd1 = nn.ConvTranspose2d(
in_channels=filters[4], # 입력 채널
out_channels=self.CatChannels, # 출력 채널
kernel_size=16, # 업샘플링 크기
stride=16 # 16배 업샘플링
)
self.hd5_UT_hd1_conv = nn.Conv2d(self.CatChannels, self.CatChannels, 3, padding=1)
self.hd5_UT_hd1_bn = nn.BatchNorm2d(self.CatChannels)
self.hd5_UT_hd1_relu = nn.ReLU(inplace=True)

Expand All @@ -187,19 +240,56 @@ def __init__(self, in_channels=3, n_classes=1, feature_scale=4, is_deconv=True,
self.bn1d_1 = nn.BatchNorm2d(self.UpChannels)
self.relu1d_1 = nn.ReLU(inplace=True)


# -------------Bilinear Upsampling--------------
self.upscore6 = nn.Upsample(scale_factor=32,mode='bilinear')###
self.upscore5 = nn.Upsample(scale_factor=16,mode='bilinear')
self.upscore4 = nn.Upsample(scale_factor=8,mode='bilinear')
self.upscore3 = nn.Upsample(scale_factor=4,mode='bilinear')
self.upscore2 = nn.Upsample(scale_factor=2, mode='bilinear')
# -------------Learnable Upsampling using ConvTranspose2d--------------
self.upscore6 = nn.ConvTranspose2d(
in_channels=n_classes, # 입력 채널 수
out_channels=n_classes, # 출력 채널 수 (Segmentation 결과 채널 유지)
kernel_size=64, # 업샘플링 커널 크기
stride=32, # 32배 업샘플링
padding=16 # 출력 크기를 동일하게 맞추기 위한 패딩
)

self.upscore5 = nn.ConvTranspose2d(
in_channels=n_classes,
out_channels=n_classes,
kernel_size=32,
stride=16,
padding=8
)

self.upscore4 = nn.ConvTranspose2d(
in_channels=n_classes,
out_channels=n_classes,
kernel_size=16,
stride=8,
padding=4
)

self.upscore3 = nn.ConvTranspose2d(
in_channels=n_classes,
out_channels=n_classes,
kernel_size=8,
stride=4,
padding=2
)

self.upscore2 = nn.ConvTranspose2d(
in_channels=n_classes,
out_channels=n_classes,
kernel_size=4,
stride=2,
padding=1
)

# DeepSup
self.outconv1 = nn.Conv2d(self.UpChannels, n_classes, 3, padding=1)
self.outconv2 = nn.Conv2d(self.UpChannels, n_classes, 3, padding=1)
self.outconv3 = nn.Conv2d(self.UpChannels, n_classes, 3, padding=1)
self.outconv4 = nn.Conv2d(self.UpChannels, n_classes, 3, padding=1)
self.outconv5 = nn.Conv2d(filters[4], n_classes, 3, padding=1)

self.cls = nn.Sequential(
nn.Dropout(p=0.2), # Dropout으로 오버피팅 방지
nn.Conv2d(filters[4], n_classes, 1), # 클래스 수 반영
Expand All @@ -212,6 +302,8 @@ def __init__(self, in_channels=3, n_classes=1, feature_scale=4, is_deconv=True,
init_weights(m, init_type='kaiming')
elif isinstance(m, nn.BatchNorm2d):
init_weights(m, init_type='kaiming')
elif isinstance(m, nn.ConvTranspose2d):
init_weights(m, init_type='kaiming')


def dotProduct(self,seg,cls):
Expand Down Expand Up @@ -286,15 +378,16 @@ def forward(self, inputs):
d2 = self.upscore2(d2) # 128->256

d1 = self.outconv1(hd1) # 256

'''
d1 = self.dotProduct(d1, cls_branch_mask)
d2 = self.dotProduct(d2, cls_branch_mask)
d3 = self.dotProduct(d3, cls_branch_mask)
d4 = self.dotProduct(d4, cls_branch_mask)
d5 = self.dotProduct(d5, cls_branch_mask)

'''
if self.training:
return d1, d2, d3, d4, d5
else:
#print(d1)
return d1
21 changes: 21 additions & 0 deletions UNet3+/Code/Model/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,28 @@ def forward(self, x):
out = self.relu(out)
return out

class CBAM(nn.Module):
def __init__(self, channels, reduction=16):
super(CBAM, self).__init__()
self.channel_attention = nn.Sequential(
nn.AdaptiveAvgPool2d(1),
nn.Conv2d(channels, channels // reduction, 1, bias=False),
nn.ReLU(inplace=True),
nn.Conv2d(channels // reduction, channels, 1, bias=False),
nn.Sigmoid()
)
self.spatial_attention = nn.Sequential(
nn.Conv2d(2, 1, kernel_size=7, padding=3, bias=False),
nn.Sigmoid()
)

def forward(self, x):
# Channel Attention
ca = self.channel_attention(x) * x
# Spatial Attention
sa_input = torch.cat([ca.mean(dim=1, keepdim=True), ca.max(dim=1, keepdim=True)[0]], dim=1)
sa = self.spatial_attention(sa_input)
return sa * ca


class unetConv2(nn.Module):
Expand Down
2 changes: 1 addition & 1 deletion UNet3+/Code/Train.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def train(model, data_loader, val_loader, criterion, optimizer, scheduler):
best_dice = 0.0

# 손실 가중치 (Deep Supervision)
deep_sup_weights = [0.45, 0.35, 0.25, 0.2, 0.2] # 각 출력에 대한 가중치
deep_sup_weights = [0.5, 0.35, 0.25, 0.2, 0.15] # 각 출력에 대한 가중치

# Mixed Precision Scaler 생성
scaler = GradScaler()
Expand Down
15 changes: 8 additions & 7 deletions UNet3+/Code/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
IMAGE_ROOT = "/data/ephemeral/home/MCG/data/train/DCM"
LABEL_ROOT = "/data/ephemeral/home/MCG/data/train/outputs_json"

'''CLASSES = [
'''
CLASSES = [
'finger-1', 'finger-2', 'finger-3', 'finger-4', 'finger-5',
'finger-6', 'finger-7', 'finger-8', 'finger-9', 'finger-10',
'finger-11', 'finger-12', 'finger-13', 'finger-14', 'finger-15',
Expand Down Expand Up @@ -40,27 +41,27 @@
RANDOM_SEED = 21

# 적절하게 조절
NUM_EPOCHS =75
NUM_EPOCHS =52
VAL_EVERY = 1

BATCH_SIZE = 2
BATCH_SIZE = 4
IMSIZE=480

LR = 0.0003
MILESTONES=[20,37,52,62,67]
MILESTONES=[20,30,37]
GAMMA=0.2


SAVED_DIR = "/data/ephemeral/home/MCG/UNetRefactored/Creadted_model/"
MODELNAME="othersCrop_AddBottleNeck_75.pt"
MODELNAME="othersCrop_AddBottleNeck_ConvTrans_52.pt"
if not os.path.isdir(SAVED_DIR):
os.mkdir(SAVED_DIR)



INFERENCE_MODEL_NAME="othersCrop_AddBottleNeck_75.pt"
INFERENCE_MODEL_NAME="othersCrop_AddBottleNeck_ConvTrans_52.pt"

TEST_IMAGE_ROOT="/data/ephemeral/home/MCG/data/test/DCM"

CSVDIR="/data/ephemeral/home/MCG/UNetRefactored/CSV"
CSVNAME="othersCrop_AddBottleNeck_75.csv"
CSVNAME="othersCrop_AddBottleNeck_ConvTrans_52.csv"

0 comments on commit 51412da

Please sign in to comment.