-
Notifications
You must be signed in to change notification settings - Fork 2
/
training.py
140 lines (114 loc) · 6.56 KB
/
training.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
__author__ = 'marvinler'
import argparse
import numpy as np
import sklearn.metrics as metrics
import torch
import torch.utils.data
from dataset import Dataset
from model import instantiate_sparseconvmil
def _define_args():
parser = argparse.ArgumentParser(description='SparseConvMIL: Sparse Convolutional Context-Aware Multiple Instance '
'Learning for Whole Slide Image Classification')
parser.add_argument('--slide-parent-folder', type=str, default='sample_data', metavar='PATH',
help='path of parent folder containing preprocessed slides data')
parser.add_argument('--slide-labels-filepath', type=str, default='sample_data/labels.csv', metavar='PATH',
help='path of CSV-file containing slide labels')
parser.add_argument('--epochs', type=int, default=200, metavar='N', help='number of training epochs')
parser.add_argument('--lr', type=float, default=2e-3, metavar='LR', help='learning rate')
parser.add_argument('--reg', type=float, default=1e-6, metavar='R', help='weight decay')
# Model parameters
parser.add_argument('--tile-embedder', type=str, default='resnet18', metavar='MODEL', nargs='*',
help='type of resnet architecture for the tile embedder')
parser.add_argument('--tile-embedder-pretrained', action='store_true', default=False,
help='use Imagenet-pretrained tile embedder architecture')
parser.add_argument('--sparse-conv-n-channels-conv1', type=int, default=32,
help='number of channels of first convolution of the sparse-input CNN pooling')
parser.add_argument('--sparse-conv-n-channels-conv2', type=int, default=32,
help='number of channels of first convolution of the sparse-input CNN pooling')
parser.add_argument('--sparse-map-downsample', type=int, default=10, help='downsampling factor of the sparse map')
parser.add_argument('--wsi-embedding-classifier-n-inner-neurons', type=int, default=32,
help='number of inner neurons for the WSI embedding classifier')
# Dataset parameters
parser.add_argument('--batch-size', type=int, default=2, metavar='SIZE',
help='number of slides sampled per iteration')
parser.add_argument('--n-tiles-per-wsi', type=int, default=5, metavar='SIZE',
help='number of tiles to be sampled per WSI')
# Miscellaneous parameters
parser.add_argument('--j', type=int, default=10, metavar='N_WORKERS', help='number of workers for dataloader')
args = parser.parse_args()
hyper_parameters = {
'slide_parent_folder': args.slide_parent_folder,
'slide_labels_filepath': args.slide_labels_filepath,
'epochs': args.epochs,
'lr': args.lr,
'reg': args.reg,
'tile_embedder': args.tile_embedder,
'tile_embedder_pretrained': args.tile_embedder_pretrained,
'sparse_conv_n_channels_conv1': args.sparse_conv_n_channels_conv1,
'sparse_conv_n_channels_conv2': args.sparse_conv_n_channels_conv2,
'sparse_map_downsample': args.sparse_map_downsample,
'wsi_embedding_classifier_n_inner_neurons': args.wsi_embedding_classifier_n_inner_neurons,
'batch_size': args.batch_size,
'n_tiles_per_wsi': args.n_tiles_per_wsi,
'j': args.j,
}
return hyper_parameters
def get_dataloader(dataset, batch_size, shuffle, num_workers):
    """
    Wrap a dataset into a pytorch DataLoader.

    :param dataset: dataset to iterate over
    :param batch_size: number of samples collated into each batch
    :param shuffle: whether the samples are reshuffled at every epoch
    :param num_workers: number of subprocesses used for data loading
    :return: the configured torch.utils.data.DataLoader
    """
    loader = torch.utils.data.DataLoader(dataset,
                                         batch_size=batch_size,
                                         shuffle=shuffle,
                                         num_workers=num_workers)
    return loader
def perform_epoch(mil_model, dataloader, optimizer, loss_function, device='cuda'):
    """
    Perform a complete training epoch by looping through all data of the dataloader.

    :param mil_model: MIL model to be trained, called as ``mil_model(data, locations)``
    :param dataloader: iterable yielding ``(data, locations, slides_labels, slides_ids)`` batches
    :param optimizer: pytorch optimizer updating the parameters of ``mil_model``
    :param loss_function: loss function called as ``loss_function(predictions, labels)`` to compute gradients
    :param device: device each batch is moved to before the forward pass; the default ``'cuda'`` (current CUDA
        device) reproduces the previous hard-coded ``.cuda()`` behaviour
    :return: (mean of per-batch losses, balanced accuracy of the argmax predictions over the whole epoch)
    :raises ValueError: if ``dataloader`` yields no batch (no metric can be computed)
    """
    # Make sure any train/eval-dependent layers behave as in training, even if the caller left eval mode on
    mil_model.train()

    losses = []
    predicted_classes = []
    ground_truths = []
    for data, locations, slides_labels, _slides_ids in dataloader:
        data = data.to(device)
        locations = locations.to(device)

        optimizer.zero_grad()
        predictions = mil_model(data, locations)
        loss = loss_function(predictions, slides_labels.to(device))
        loss.backward()
        optimizer.step()

        # Store data for final epoch average measures; only the predicted class of each slide is kept
        losses.append(loss.item())
        predicted_classes.extend(predictions.detach().argmax(dim=1).cpu().tolist())
        ground_truths.extend(slides_labels.tolist())

    if not losses:
        raise ValueError('dataloader yielded no batch: cannot compute epoch loss and balanced accuracy')

    return np.mean(losses), metrics.balanced_accuracy_score(ground_truths, predicted_classes)
def main(hyper_parameters):
    """
    Train a SparseConvMIL model on the configured slides and print per-epoch training metrics.

    :param hyper_parameters: dict of hyper-parameters as returned by ``_define_args`` (keys such as
        ``'slide_parent_folder'``, ``'epochs'``, ``'lr'``, ``'reg'``, ``'batch_size'``, ``'j'``)
    """
    # Loads dataset and dataloader
    print('Loading data')
    dataset = Dataset(hyper_parameters['slide_parent_folder'], hyper_parameters['slide_labels_filepath'],
                      hyper_parameters['n_tiles_per_wsi'])
    n_classes = dataset.n_classes
    dataloader = get_dataloader(dataset, hyper_parameters['batch_size'], True, hyper_parameters['j'])
    print(' done')

    # Loads MIL model, optimizer and loss function
    print('Loading SparseConvMIL model')
    sparseconvmil_model = instantiate_sparseconvmil(hyper_parameters['tile_embedder'],
                                                    hyper_parameters['tile_embedder_pretrained'],
                                                    hyper_parameters['sparse_conv_n_channels_conv1'],
                                                    hyper_parameters['sparse_conv_n_channels_conv2'],
                                                    # NOTE(review): the meaning of these two hard-coded 3s is
                                                    # defined by instantiate_sparseconvmil (not visible here)
                                                    3, 3, hyper_parameters['sparse_map_downsample'],
                                                    hyper_parameters['wsi_embedding_classifier_n_inner_neurons'],
                                                    n_classes)
    # DataParallel requires the wrapped module's parameters on device_ids[0] before the first forward pass, and
    # perform_epoch moves batches to CUDA; .cuda() is a no-op if the model is already on the GPU.
    sparseconvmil_model = torch.nn.DataParallel(sparseconvmil_model.cuda())
    print(' done')

    optimizer = torch.optim.Adam(sparseconvmil_model.parameters(), hyper_parameters['lr'],
                                 weight_decay=hyper_parameters['reg'])
    loss_function = torch.nn.CrossEntropyLoss()

    # Loop through all epochs; epochs are displayed 1-based so the last one reads N/N
    print('Starting training...')
    n_epochs = hyper_parameters["epochs"]
    for epoch in range(1, n_epochs + 1):
        loss, bac = perform_epoch(sparseconvmil_model, dataloader, optimizer, loss_function)
        print('Epoch', f'{epoch:3d}/{n_epochs}', f' loss={loss:.3f}', f' bac={bac:.3f}')
    print(' done')
if __name__ == '__main__':
    # Script entry point: parse command-line hyper-parameters, then launch training with them
    parsed_hyper_parameters = _define_args()
    main(parsed_hyper_parameters)