-
Notifications
You must be signed in to change notification settings - Fork 112
/
Copy pathtrain.py
114 lines (94 loc) · 4.55 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
from dataset.dataset_lits_val import Val_Dataset
from dataset.dataset_lits_train import Train_Dataset
from torch.utils.data import DataLoader
import torch
import torch.optim as optim
from tqdm import tqdm
import config
from models import UNet, ResUNet , KiUNet_min, SegNet
from utils import logger, weights_init, metrics, common, loss
import os
import numpy as np
from collections import OrderedDict
def val(model, val_loader, loss_func, n_labels):
    """Run one evaluation pass over ``val_loader`` and summarize its metrics.

    The model is switched to eval mode and every batch is processed with
    gradient tracking disabled. Targets are passed through
    ``common.to_one_hot_3d`` before the loss and Dice statistics are
    accumulated.

    Note: this function reads the module-level ``device`` variable, which is
    assigned in the ``__main__`` section of this script.

    Args:
        model: Network to evaluate; called as ``model(data)``.
        val_loader: Iterable of ``(data, target)`` batches with a length.
        loss_func: Callable taking ``(output, target)`` and returning a loss.
        n_labels: Number of label classes.

    Returns:
        OrderedDict with ``Val_Loss`` and ``Val_dice_liver`` (entry 1 of the
        running Dice average), plus ``Val_dice_tumor`` (entry 2) when
        ``n_labels`` equals 3.
    """
    model.eval()
    loss_meter = metrics.LossAverage()
    dice_meter = metrics.DiceAverage(n_labels)

    with torch.no_grad():
        batches = tqdm(enumerate(val_loader), total=len(val_loader))
        for _, (data, target) in batches:
            data = data.float().to(device)
            target = common.to_one_hot_3d(target.long(), n_labels).to(device)

            output = model(data)
            batch_loss = loss_func(output, target)

            # Weight the running loss by the number of samples in the batch.
            loss_meter.update(batch_loss.item(), data.size(0))
            dice_meter.update(output, target)

    summary = OrderedDict()
    summary['Val_Loss'] = loss_meter.avg
    summary['Val_dice_liver'] = dice_meter.avg[1]
    if n_labels == 3:
        summary['Val_dice_tumor'] = dice_meter.avg[2]
    return summary
def train(model, train_loader, optimizer, loss_func, n_labels, alpha):
    """Train ``model`` for one epoch with deep supervision.

    The model is expected to return an indexable of four outputs. The loss on
    output 3 is the main term; the losses on outputs 0-2 are added with
    weight ``alpha``. Only the main term and output 3 feed the reported loss
    and Dice statistics.

    Note: this function reads the module-level ``epoch`` and ``device``
    variables, which are assigned in the ``__main__`` section of this script.

    Args:
        model: Network to train; called as ``model(data)``.
        train_loader: Iterable of ``(data, target)`` batches with a length.
        optimizer: Optimizer stepping the model parameters.
        loss_func: Callable taking ``(output, target)`` and returning a loss.
        n_labels: Number of label classes.
        alpha: Weight applied to the sum of the three auxiliary losses.

    Returns:
        OrderedDict with ``Train_Loss`` and ``Train_dice_liver`` (entry 1 of
        the running Dice average), plus ``Train_dice_tumor`` (entry 2) when
        ``n_labels`` equals 3.
    """
    current_lr = optimizer.state_dict()['param_groups'][0]['lr']
    print("=======Epoch:{}=======lr:{}".format(epoch, current_lr))

    model.train()
    loss_meter = metrics.LossAverage()
    dice_meter = metrics.DiceAverage(n_labels)

    batches = tqdm(enumerate(train_loader), total=len(train_loader))
    for _, (data, target) in batches:
        data = data.float().to(device)
        target = common.to_one_hot_3d(target.long(), n_labels).to(device)

        optimizer.zero_grad()
        output = model(data)

        # One loss per output head, evaluated in head order 0..3.
        aux0, aux1, aux2, main_loss = [
            loss_func(output[head], target) for head in range(4)
        ]
        total_loss = main_loss + alpha * (aux0 + aux1 + aux2)
        total_loss.backward()
        optimizer.step()

        # Statistics track the final head only, not the weighted total.
        loss_meter.update(main_loss.item(), data.size(0))
        dice_meter.update(output[3], target)

    summary = OrderedDict()
    summary['Train_Loss'] = loss_meter.avg
    summary['Train_dice_liver'] = dice_meter.avg[1]
    if n_labels == 3:
        summary['Train_dice_tumor'] = dice_meter.avg[2]
    return summary
if __name__ == '__main__':
    # Script entry point: build the loaders, model and optimizer from
    # config.args, then run the train/validate loop with checkpointing,
    # deep-supervision weight decay and optional early stopping.
    args = config.args
    save_path = os.path.join('./experiments', args.save)
    # makedirs also creates a missing './experiments' parent and does not
    # fail if the directory already exists (no check-then-create race).
    os.makedirs(save_path, exist_ok=True)
    device = torch.device('cpu' if args.cpu else 'cuda')

    # data info
    train_loader = DataLoader(dataset=Train_Dataset(args), batch_size=args.batch_size,
                              num_workers=args.n_threads, shuffle=True)
    val_loader = DataLoader(dataset=Val_Dataset(args), batch_size=1,
                            num_workers=args.n_threads, shuffle=False)

    # model info
    model = ResUNet(in_channel=1, out_channel=args.n_labels, training=True).to(device)
    model.apply(weights_init.init_model)
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    common.print_network(model)
    model = torch.nn.DataParallel(model, device_ids=args.gpu_id)  # multi-GPU

    # Bind the criterion to its own name so the imported `loss` module is
    # not overwritten by an instance of one of its classes.
    loss_func = loss.TverskyLoss()

    log = logger.Train_Logger(save_path, "train_log")

    best = [0, 0]  # [epoch, Val_dice_liver] of the best model seen so far
    trigger = 0  # early-stop counter: epochs since the last improvement
    alpha = 0.4  # initial weight of the deep-supervision (auxiliary) losses
    for epoch in range(1, args.epochs + 1):
        common.adjust_learning_rate(optimizer, epoch, args)
        train_log = train(model, train_loader, optimizer, loss_func, args.n_labels, alpha)
        val_log = val(model, val_loader, loss_func, args.n_labels)
        log.update(epoch, train_log, val_log)

        # Save checkpoint.
        state = {'net': model.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch': epoch}
        torch.save(state, os.path.join(save_path, 'latest_model.pth'))
        trigger += 1
        if val_log['Val_dice_liver'] > best[1]:
            print('Saving best model')
            torch.save(state, os.path.join(save_path, 'best_model.pth'))
            best[0] = epoch
            best[1] = val_log['Val_dice_liver']
            trigger = 0
        print('Best performance at Epoch: {} | {}'.format(best[0], best[1]))

        # Decay the deep-supervision weight every 30 epochs.
        if epoch % 30 == 0:
            alpha *= 0.8

        # early stopping
        if args.early_stop is not None and trigger >= args.early_stop:
            print("=> early stopping")
            break

        # Ask the CUDA caching allocator to release unused memory between epochs.
        torch.cuda.empty_cache()