Skip to content

Commit

Permalink
20211216
Browse files Browse the repository at this point in the history
  • Loading branch information
xiezheng committed Dec 16, 2021
1 parent 03ca9b8 commit 168b269
Show file tree
Hide file tree
Showing 7 changed files with 122 additions and 225 deletions.
21 changes: 8 additions & 13 deletions data/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,34 +70,29 @@ def __init__(self, src_path1, src_path2, src_path3, patch_size=128, train=True):
self.img_paths = []

files1 = glob.glob(os.path.join(src_path1, '*.png'))
files1.sort()
# files1.sort()
self.len1 = len(files1)
for file_name in files1:
self.img_paths.append(file_name)
self.img_paths.append((file_name, 1))

files2 = glob.glob(os.path.join(src_path2, '*.png'))
files2.sort()
# files2.sort()
self.len2 = len(files2)
for file_name in files2:
self.img_paths.append(file_name)
self.img_paths.append((file_name, 2))

files3 = glob.glob(os.path.join(src_path3, '*.png'))
files3.sort()
# files3.sort()
self.len3 = len(files3)
for file_name in files3:
self.img_paths.append(file_name)
self.img_paths.append((file_name, 3))

self.patch_size = patch_size
self.train = train

def __getitem__(self, index):
if index < self.len1:
label = 1
elif index < self.len1 + self.len2:
label = 2
else:
label = 3
img_array = np.array(Image.open(self.img_paths[index]))
img_path, label = self.img_paths[index]
img_array = np.array(Image.open(img_path))
noisy, clean = np.split(img_array, 2, axis=1)
patch_size = self.patch_size

Expand Down
100 changes: 0 additions & 100 deletions model/DG_UNet.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,103 +57,3 @@ def forward(self, feature):
adversarial_out = self.ad_net(self.grl_layer(feature))
return adversarial_out


class DG_UNet(nn.Module):
    """U-Net style encoder-decoder for image-to-image restoration (denoising).

    Four 2x max-pool downsampling stages (32 -> 64 -> 128 -> 256 channels), a
    512-channel bottleneck, then four ConvTranspose2d upsampling stages with
    skip connections concatenated from the matching encoder level.

    Input spatial size must be divisible by 16 (four 2x pools).

    Args:
        in_channels:  number of channels of the input image (default 3).
        out_channels: number of channels of the restored output (default 3).
    """

    def __init__(self, in_channels=3, out_channels=3):
        super(DG_UNet, self).__init__()

        # Encoder: two 3x3 convs per level followed by a 2x max-pool.
        self.conv1_1 = nn.Conv2d(in_channels, 32, kernel_size=3, stride=1, padding=1)
        self.conv1_2 = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1)
        self.pool1 = nn.MaxPool2d(kernel_size=2)

        self.conv2_1 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.conv2_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.pool2 = nn.MaxPool2d(kernel_size=2)

        self.conv3_1 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.conv3_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
        self.pool3 = nn.MaxPool2d(kernel_size=2)

        self.conv4_1 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
        self.conv4_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.pool4 = nn.MaxPool2d(kernel_size=2)

        # Bottleneck.
        self.conv5_1 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1)
        self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)

        # Decoder: 2x transposed-conv upsample, concat skip, two 3x3 convs.
        # The first conv of each level takes doubled channels because of the
        # skip-connection concatenation.
        self.upv6 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.conv6_1 = nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=1)
        self.conv6_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)

        self.upv7 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.conv7_1 = nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1)
        self.conv7_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)

        self.upv8 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.conv8_1 = nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1)
        self.conv8_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)

        self.upv9 = nn.ConvTranspose2d(64, 32, 2, stride=2)
        self.conv9_1 = nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1)
        self.conv9_2 = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1)

        # 1x1 projection back to the output channel count.
        self.conv10_1 = nn.Conv2d(32, out_channels, kernel_size=1, stride=1)

    def forward(self, x):
        """Run the U-Net. Returns a tensor of the same spatial size as ``x``."""
        conv1 = self.lrelu(self.conv1_1(x))
        conv1 = self.lrelu(self.conv1_2(conv1))
        pool1 = self.pool1(conv1)

        conv2 = self.lrelu(self.conv2_1(pool1))
        conv2 = self.lrelu(self.conv2_2(conv2))
        # BUGFIX: the original reused self.pool1 for stages 2-4, leaving the
        # declared pool2/pool3/pool4 modules unused. MaxPool2d(2) is stateless
        # so the numerical result is unchanged, but each stage now uses its
        # own declared pooling layer for consistency.
        pool2 = self.pool2(conv2)

        conv3 = self.lrelu(self.conv3_1(pool2))
        conv3 = self.lrelu(self.conv3_2(conv3))
        pool3 = self.pool3(conv3)

        conv4 = self.lrelu(self.conv4_1(pool3))
        conv4 = self.lrelu(self.conv4_2(conv4))
        pool4 = self.pool4(conv4)

        conv5 = self.lrelu(self.conv5_1(pool4))
        conv5 = self.lrelu(self.conv5_2(conv5))

        up6 = self.upv6(conv5)
        up6 = torch.cat([up6, conv4], 1)
        conv6 = self.lrelu(self.conv6_1(up6))
        conv6 = self.lrelu(self.conv6_2(conv6))

        up7 = self.upv7(conv6)
        up7 = torch.cat([up7, conv3], 1)
        conv7 = self.lrelu(self.conv7_1(up7))
        conv7 = self.lrelu(self.conv7_2(conv7))

        up8 = self.upv8(conv7)
        up8 = torch.cat([up8, conv2], 1)
        conv8 = self.lrelu(self.conv8_1(up8))
        conv8 = self.lrelu(self.conv8_2(conv8))

        up9 = self.upv9(conv8)
        up9 = torch.cat([up9, conv1], 1)
        conv9 = self.lrelu(self.conv9_1(up9))
        conv9 = self.lrelu(self.conv9_2(conv9))

        conv10 = self.conv10_1(conv9)
        # Pixel-shuffle upsampling was disabled in the original; the network
        # output is the 1x1 conv result directly.
        # out = nn.functional.pixel_shuffle(conv10, 2)
        out = conv10
        return out

    def _initialize_weights(self):
        """Initialize conv / transposed-conv weights (and conv biases) from
        N(0, 0.02). Not called automatically; callers invoke it explicitly."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                m.weight.data.normal_(0.0, 0.02)
                if m.bias is not None:
                    m.bias.data.normal_(0.0, 0.02)
            if isinstance(m, nn.ConvTranspose2d):
                m.weight.data.normal_(0.0, 0.02)

    def lrelu(self, x):
        """Leaky ReLU with negative slope 0.2, written as an elementwise max."""
        outt = torch.max(0.2 * x, x)
        return outt
56 changes: 25 additions & 31 deletions train_multiDataset_ddp_baseline.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from warmup_scheduler import GradualWarmupScheduler
from tensorboardX import SummaryWriter
from skimage.measure import compare_psnr

from model.ELD_UNet import ELD_UNet
from utils.util import *
from model.DG_UNet import *
from loss.ad_loss import *
Expand All @@ -30,8 +32,7 @@ def train(opt, epoch, model, data_loader, optimizer, scheduler, criterion, logge
for iteration, batch in enumerate(data_loader):
# load data
(noisy, target, label) = batch
noisy, target = noisy.cuda(opt.local_rank, non_blocking=True), target.cuda(opt.local_rank,
non_blocking=True)
noisy, target = noisy.cuda(opt.local_rank, non_blocking=True), target.cuda(opt.local_rank, non_blocking=True)

# forward
prediction = model(noisy)
Expand All @@ -47,21 +48,22 @@ def train(opt, epoch, model, data_loader, optimizer, scheduler, criterion, logge
dist.barrier()
reduced_l1_loss = reduce_mean(l1_loss, opt.nProcs)
reduced_total_loss = reduce_mean(total_loss, opt.nProcs)

epoch_l1_loss.update(reduced_l1_loss.item(), noisy.size(0))
epoch_total_loss.update(reduced_total_loss.item(), noisy.size(0))

if iteration % opt.print_freq == 0:
ddp_logger_info(
'Train epoch: [{:d}/{:d}]\titeration: [{:d}/{:d}]\tlr={:.6f}\tl1_loss={:.4f}\ttotal_loss={:.4f}'
.format(epoch, opt.nEpochs, iteration, len(data_loader), scheduler.get_lr()[0], epoch_l1_loss.avg, epoch_total_loss.avg),
logger, opt.local_rank)
.format(epoch, opt.nEpochs, iteration, len(data_loader), scheduler.get_lr()[0], epoch_l1_loss.avg,
epoch_total_loss.avg), logger, opt.local_rank)

ddp_writer_add_scalar('Train_L1_loss', epoch_l1_loss.avg, epoch, writer, opt.local_rank)
ddp_writer_add_scalar('Train_total_loss', epoch_total_loss.avg, epoch, writer, opt.local_rank)
ddp_writer_add_scalar('Learning_rate', scheduler.get_lr()[0], epoch, writer, opt.local_rank)
ddp_logger_info(
'||==> Train epoch: [{:d}/{:d}]\tlr={:.6f}\tl1_loss={:.4f}\ttotal_loss={:.4f}\tcost_time={:.4f}'
.format(epoch, opt.nEpochs, scheduler.get_lr()[0], epoch_l1_loss.avg,
epoch_total_loss.avg, time.time() - t0),
.format(epoch, opt.nEpochs, scheduler.get_lr()[0], epoch_l1_loss.avg, epoch_total_loss.avg, time.time() - t0),
logger, opt.local_rank)


Expand All @@ -79,7 +81,6 @@ def valid(opt, epoch, data_loader, model, criterion, logger, writer):
prediction = torch.clamp(prediction, 0.0, 1.0)

loss = criterion(prediction, target)

prediction = prediction.data.cpu().numpy().astype(np.float32)
target = target.data.cpu().numpy().astype(np.float32)
for i in range(prediction.shape[0]):
Expand All @@ -88,6 +89,7 @@ def valid(opt, epoch, data_loader, model, criterion, logger, writer):
dist.barrier()
reduced_psnr = reduce_mean(torch.Tensor([psnr.avg]).cuda(opt.local_rank, non_blocking=True), opt.nProcs)
reduced_loss = reduce_mean(loss, opt.nProcs)

psnr_val.update(reduced_psnr.item(), prediction.shape[0])
loss_val.update(reduced_loss.item(), prediction.shape[0])

Expand Down Expand Up @@ -117,7 +119,7 @@ def main():
parser.add_argument('--nEpochs', type=int, default=150, help='number of epochs to train for')
parser.add_argument('--lr', type=float, default=2e-4, help='learning rate. default=0.0002')
parser.add_argument('--lr_min', type=float, default=1e-5, help='minimum learning rate. default=0.000001')
parser.add_argument('--start_iter', type=int, default=1, help='starting epoch')
parser.add_argument('--start_epoch', type=int, default=1, help='starting epoch')
parser.add_argument('--weight_decay', type=float, default=1e-8, help='weight_decay')

# model settings
Expand All @@ -134,7 +136,6 @@ def main():

# distributed
parser.add_argument('--local_rank', default=0, type=int, help='node rank for distributed training')

opt = parser.parse_args()

# initialize
Expand Down Expand Up @@ -168,11 +169,10 @@ def main():
torch.cuda.set_device(device=opt.local_rank)

# load dataset
ddp_logger_info('Loading datasets {}, {}, {}, Batch Size: {}, Patch Size: {}'.format(opt.data_set1, opt.data_set2,
opt.data_set3,
opt.batch_size,
opt.patch_size), logger,
opt.local_rank)
ddp_logger_info('Loading datasets {}, {}, {}, Batch Size: {}, Patch Size: {}'
.format(opt.data_set1, opt.data_set2, opt.data_set3, opt.batch_size, opt.patch_size),
logger, opt.local_rank)

train_set = LoadMultiDataset(src_path1=os.path.join(opt.data_dir, opt.data_set1, 'train'),
src_path2=os.path.join(opt.data_dir, opt.data_set2, 'train'),
src_path3=os.path.join(opt.data_dir, opt.data_set3, 'train'),
Expand All @@ -181,9 +181,8 @@ def main():
train_sampler = DistributedSampler(train_set)
train_data_loader = DataLoaderX(dataset=train_set, batch_size=opt.batch_size,
num_workers=opt.num_workers, pin_memory=True, sampler=train_sampler)
ddp_logger_info(
'Train dataset length: {} 1:{} 2:{} 3:{}'.format(len(train_data_loader), train_set.len1, train_set.len2,
train_set.len3), logger, opt.local_rank)
ddp_logger_info('Train dataset length: {} 1:{} 2:{} 3:{}'.format(len(train_data_loader), train_set.len1,
train_set.len2, train_set.len3), logger, opt.local_rank)

val_set = LoadDataset(src_path=os.path.join(opt.data_dir, opt.data_set_test, 'test'),
patch_size=opt.test_patch_size,
Expand All @@ -195,8 +194,7 @@ def main():

# load network
ddp_logger_info('Building model {}'.format(opt.model_type), logger, opt.local_rank)
model = DG_UNet()

model = ELD_UNet()
ddp_logger_info("Push model to distribute data parallel!", logger, opt.local_rank)
model.cuda(device=opt.local_rank)
ddp_logger_info('model={}'.format(model), logger, opt.local_rank)
Expand All @@ -221,16 +219,15 @@ def main():

# resume
if opt.pretrain_model != '':
model, at_net, start_epoch, optimizer, psnr_best = load_model(opt.pretrain_model, model,
optimizer, logger)
model, start_epoch, optimizer, psnr_best = load_model(opt.pretrain_model, model,
optimizer, logger, opt.local_rank)
start_epoch += 1
for i in range(1, start_epoch):
scheduler.step()
ddp_logger_info('Resume start epoch: {}, Learning rate:{:.6f}'.format(start_epoch, scheduler.get_lr()[
0]),
ddp_logger_info('Resume start epoch: {}, Learning rate:{:.6f}'.format(start_epoch, scheduler.get_lr()[0]),
logger, opt.local_rank)
else:
start_epoch = opt.start_iter
start_epoch = opt.start_epoch
ddp_logger_info(
'Start epoch: {}, Learning rate:{:.6f}'.format(start_epoch, scheduler.get_lr()[0]), logger,
opt.local_rank)
Expand All @@ -242,20 +239,17 @@ def main():
scheduler.step()

# training
train(opt, epoch, model, train_data_loader, optimizer,
scheduler,
criterion, logger, writer)
train(opt, epoch, model, train_data_loader, optimizer, scheduler, criterion, logger, writer)
# validation
psnr = valid(opt, epoch, val_data_loader, model, criterion, logger, writer)

# save model
if opt.local_rank == 0:
if psnr > psnr_best:
if psnr >= psnr_best:
psnr_best = psnr
epoch_best = epoch
save_model(os.path.join(checkpoint_folder, "model_best.pth"), epoch, model, optimizer, psnr_best,
logger)
save_model(os.path.join(checkpoint_folder, "model_latest.pth"), epoch, model, optimizer, psnr_best,logger)
save_model(os.path.join(checkpoint_folder, "model_best.pth"), epoch, model, optimizer, psnr_best, logger)
save_model(os.path.join(checkpoint_folder, "model_latest.pth"), epoch, model, optimizer, psnr_best, logger)

ddp_logger_info('||==> best_epoch = {}, best_psnr = {}'.format(epoch_best, psnr_best), logger, opt.local_rank)

Expand Down
Loading

0 comments on commit 168b269

Please sign in to comment.